mihaiciorobitca committed on
Commit
fe0de51
·
verified ·
1 Parent(s): c7c90f2

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ComfyUI/comfy/comfy_types/examples/example_nodes.py +28 -0
  2. ComfyUI/comfy/comfy_types/examples/input_options.png +0 -0
  3. ComfyUI/comfy/comfy_types/examples/input_types.png +0 -0
  4. ComfyUI/comfy/comfy_types/examples/required_hint.png +0 -0
  5. ComfyUI/comfy/ldm/ace/attention.py +761 -0
  6. ComfyUI/comfy/ldm/ace/lyric_encoder.py +1067 -0
  7. ComfyUI/comfy/ldm/ace/model.py +385 -0
  8. ComfyUI/comfy/ldm/ace/vae/music_log_mel.py +113 -0
  9. ComfyUI/comfy/ldm/audio/autoencoder.py +276 -0
  10. ComfyUI/comfy/ldm/audio/dit.py +896 -0
  11. ComfyUI/comfy/ldm/audio/embedders.py +108 -0
  12. ComfyUI/comfy/ldm/aura/mmdit.py +498 -0
  13. ComfyUI/comfy/ldm/cascade/common.py +154 -0
  14. ComfyUI/comfy/ldm/cascade/controlnet.py +92 -0
  15. ComfyUI/comfy/ldm/cascade/stage_a.py +259 -0
  16. ComfyUI/comfy/ldm/cascade/stage_b.py +256 -0
  17. ComfyUI/comfy/ldm/cascade/stage_c.py +273 -0
  18. ComfyUI/comfy/ldm/cascade/stage_c_coder.py +98 -0
  19. ComfyUI/comfy/ldm/chroma/layers.py +181 -0
  20. ComfyUI/comfy/ldm/chroma/model.py +270 -0
  21. ComfyUI/comfy/ldm/cosmos/blocks.py +797 -0
  22. ComfyUI/comfy/ldm/cosmos/model.py +512 -0
  23. ComfyUI/comfy/ldm/cosmos/position_embedding.py +207 -0
  24. ComfyUI/comfy/ldm/cosmos/predict2.py +864 -0
  25. ComfyUI/comfy/ldm/cosmos/vae.py +131 -0
  26. ComfyUI/comfy/ldm/flux/controlnet.py +208 -0
  27. ComfyUI/comfy/ldm/flux/layers.py +278 -0
  28. ComfyUI/comfy/ldm/flux/math.py +45 -0
  29. ComfyUI/comfy/ldm/flux/model.py +244 -0
  30. ComfyUI/comfy/ldm/flux/redux.py +25 -0
  31. ComfyUI/comfy/ldm/hidream/model.py +802 -0
  32. ComfyUI/comfy/ldm/hunyuan3d/model.py +135 -0
  33. ComfyUI/comfy/ldm/hunyuan3d/vae.py +587 -0
  34. ComfyUI/comfy/ldm/hunyuan_video/model.py +355 -0
  35. ComfyUI/comfy/ldm/hydit/attn_layers.py +218 -0
  36. ComfyUI/comfy/ldm/hydit/controlnet.py +311 -0
  37. ComfyUI/comfy/ldm/hydit/models.py +417 -0
  38. ComfyUI/comfy/ldm/hydit/poolers.py +36 -0
  39. ComfyUI/comfy/ldm/hydit/posemb_layers.py +224 -0
  40. ComfyUI/comfy/ldm/lightricks/model.py +506 -0
  41. ComfyUI/comfy/ldm/lightricks/symmetric_patchifier.py +117 -0
  42. ComfyUI/comfy/ldm/lightricks/vae/causal_conv3d.py +65 -0
  43. ComfyUI/comfy/ldm/lumina/model.py +622 -0
  44. ComfyUI/comfy/ldm/models/autoencoder.py +231 -0
  45. ComfyUI/comfy/ldm/modules/attention.py +1035 -0
  46. ComfyUI/comfy/ldm/modules/ema.py +80 -0
  47. ComfyUI/comfy/ldm/modules/sub_quadratic_attention.py +275 -0
  48. ComfyUI/comfy/ldm/modules/temporal_ae.py +246 -0
  49. ComfyUI/comfy/ldm/omnigen/omnigen2.py +469 -0
  50. ComfyUI/comfy/ldm/pixart/pixartms.py +256 -0
ComfyUI/comfy/comfy_types/examples/example_nodes.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
2
+ from inspect import cleandoc
3
+
4
+
5
class ExampleNode(ComfyNodeABC):
    """An example node that just adds 1 to an input integer.

    * Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
    * This node is intended as an example for developers only.
    """

    # DESCRIPTION reuses the class docstring, so the docstring text above is
    # user-visible and must stay in sync with the node's actual behavior.
    DESCRIPTION = cleandoc(__doc__)
    CATEGORY = "examples"

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        # Declares one required integer input named "input_int".
        # NOTE(review): "defaultInput": True presumably exposes this as a
        # connectable input rather than a widget — confirm against InputTypeDict.
        return {
            "required": {
                "input_int": (IO.INT, {"defaultInput": True}),
            }
        }

    RETURN_TYPES = (IO.INT,)
    RETURN_NAMES = ("input_plus_one",)
    # Name of the method ComfyUI calls when this node executes.
    FUNCTION = "execute"

    def execute(self, input_int: int):
        # Node outputs are always a tuple, matching RETURN_TYPES.
        return (input_int + 1,)
ComfyUI/comfy/comfy_types/examples/input_options.png ADDED
ComfyUI/comfy/comfy_types/examples/input_types.png ADDED
ComfyUI/comfy/comfy_types/examples/required_hint.png ADDED
ComfyUI/comfy/ldm/ace/attention.py ADDED
@@ -0,0 +1,761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import Tuple, Union, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ import comfy.model_management
22
+ from comfy.ldm.modules.attention import optimized_attention
23
+
24
class Attention(nn.Module):
    """Multi-head attention container for the ACE-Step model.

    This module only owns the projection layers: ``to_q``/``to_k``/``to_v``,
    the optional "added" context projections (``add_q_proj`` etc.), and the
    output projection ``to_out``. The actual attention computation is
    delegated to ``self.processor`` (see ``CustomLiteLAProcessor2_0`` and
    ``CustomerAttnProcessor2_0``), which ``forward`` simply calls.

    ``operations`` supplies the Linear layer factory (ComfyUI's casting ops);
    several parameters (``eps``, ``elementwise_affine``, ``qk_norm``,
    ``dropout`` beyond the output Dropout) are accepted for interface
    compatibility but not otherwise used in this constructor.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        kv_heads: Optional[int] = None,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        processor=None,
        out_dim: int = None,
        out_context_dim: int = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
        is_causal: bool = False,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        # Projected widths; kv projections may use fewer heads than queries.
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        # Cross-attention iff a separate context dim was given.
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.is_causal = is_causal

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = out_dim // dim_head if out_dim is not None else heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        # Kept as attributes for processor compatibility; never instantiated here.
        self.group_norm = None
        self.spatial_norm = None

        self.norm_q = None
        self.norm_k = None

        self.norm_cross = None
        self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
            self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
        else:
            self.to_k = None
            self.to_v = None

        self.added_proj_bias = added_proj_bias
        if self.added_kv_proj_dim is not None:
            # Extra projections for an additional ("added") context stream.
            self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
            self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
            if self.context_pre_only is not None:
                self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
        else:
            self.add_q_proj = None
            self.add_k_proj = None
            self.add_v_proj = None

        if not self.pre_only:
            # to_out[0] = linear projection, to_out[1] = dropout (diffusers layout).
            self.to_out = nn.ModuleList([])
            self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
            self.to_out.append(nn.Dropout(dropout))
        else:
            self.to_out = None

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
        else:
            self.to_add_out = None

        self.norm_added_q = None
        self.norm_added_k = None
        # Callable implementing the actual attention math.
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        """Delegate to ``self.processor``; extra kwargs (e.g. rotary
        frequencies) are forwarded untouched."""
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
145
+
146
+
147
class CustomLiteLAProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE

    Implements a *linear* attention variant: q and k pass through a ReLU
    kernel, and attention is computed as (V·K)·Q with an extra all-ones row
    padded onto V that serves as the normalizer (divided out at the end).
    """

    def __init__(self):
        # ReLU feature map for linear attention (no softmax).
        self.kernel_func = nn.ReLU(inplace=False)
        # Numerical-stability epsilon for the normalizer division.
        self.eps = 1e-15
        # Value used for the padded normalizer row appended to V.
        self.pad_val = 1.0

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
        tensors contain rotary embeddings and are returned as real tensors.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        """
        cos, sin = freqs_cis  # [S, D]
        # Broadcast over batch and head dims: [S, D] -> [1, 1, S, D].
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        # Treat consecutive channel pairs as (real, imag) components.
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        # Rotation done in float32, cast back to the input dtype.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # Remember the latent sequence length so the joint (latent + context)
        # stream can be split back apart at the end.
        hidden_states_len = hidden_states.shape[1]

        # Accept 4D [B, C, H, W] input by flattening spatial dims to a sequence.
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        if encoder_hidden_states is not None:
            context_input_ndim = encoder_hidden_states.ndim
            if context_input_ndim == 4:
                batch_size, channel, height, width = encoder_hidden_states.shape
                encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = hidden_states.shape[0]

        # `sample` projections.
        dtype = hidden_states.dtype
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            # attention
            if not attn.is_cross_attention:
                # Joint attention: concatenate context tokens after the latents.
                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
            else:
                # NOTE(review): in the cross-attention path the projections
                # computed above are discarded and the raw tensors are used
                # instead — looks intentional upstream, but verify against the
                # original ACE-Step implementation.
                query = hidden_states
                key = encoder_hidden_states
                value = encoder_hidden_states

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape to per-head layout: query/value -> [B, H, D, S],
        # key -> [B, H, S, D] (note the extra transpose on key).
        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # RoPE expects [B, H, S, D] input.
        # query is currently [B, H, D, S]; convert to [B, H, S, D] before applying RoPE.
        query = query.permute(0, 1, 3, 2)  # [B, H, S, D] (from [B, H, D, S])

        # Apply query and key normalization if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        # query is now [B, H, S, D]; restore the [B, H, D, S] layout.
        query = query.permute(0, 1, 3, 2)  # [B, H, D, S]

        if attention_mask is not None:
            # attention_mask: [B, S] -> [B, 1, S, 1]
            attention_mask = attention_mask[:, None, :, None].to(key.dtype)  # [B, 1, S, 1]
            query = query * attention_mask.permute(0, 1, 3, 2)  # [B, H, S, D] * [B, 1, S, 1]
            if not attn.is_cross_attention:
                key = key * attention_mask  # key: [B, h, S, D] multiplied by mask [B, 1, S, 1]
                value = value * attention_mask.permute(0, 1, 3, 2)  # value is [B, h, D, S]; permute the mask to match the S dimension

        if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype)  # [B, 1, S_enc, 1]
            # Here key: [B, h, S_enc, D], value: [B, h, D, S_enc].
            key = key * encoder_attention_mask  # [B, h, S_enc, D] * [B, 1, S_enc, 1]
            value = value * encoder_attention_mask.permute(0, 1, 3, 2)  # [B, h, D, S_enc] * [B, 1, 1, S_enc]

        # Linear-attention feature map (ReLU) applied to q and k.
        query = self.kernel_func(query)
        key = self.kernel_func(key)

        # Accumulate the linear-attention matmuls in float32 for stability.
        query, key, value = query.float(), key.float(), value.float()

        # Pad an all-ones row onto V's feature dim; after (V·K)·Q its output
        # row holds the per-position normalizer.
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)

        vk = torch.matmul(value, key)

        hidden_states = torch.matmul(vk, query)

        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.float()

        # Divide out the normalizer row (last feature channel).
        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)

        # Merge heads back: [B, H*D, S] -> [B, S, H*D].
        hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)

        hidden_states = hidden_states.to(dtype)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(dtype)

        # Split the attention outputs.
        if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : hidden_states_len],
                hidden_states[:, hidden_states_len:],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # Restore 4D layout if the input was 4D.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if encoder_hidden_states is not None and context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # Clamp to fp16 range when autocasting to fp16 to avoid inf.
        if torch.get_autocast_gpu_dtype() == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
            if encoder_hidden_states is not None:
                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return hidden_states, encoder_hidden_states
323
+
324
+
325
class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Standard (softmax) attention with optional RoPE on q/k; the actual SDP
    computation is delegated to ComfyUI's ``optimized_attention``.
    """

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
        tensors contain rotary embeddings and are returned as real tensors.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        """
        cos, sin = freqs_cis  # [S, D]
        # Broadcast over batch and head dims.
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        # Consecutive channel pairs act as (real, imag) components.
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:

        residual = hidden_states
        input_ndim = hidden_states.ndim

        # Accept 4D [B, C, H, W] input by flattening spatial dims to a sequence.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention when no context is given; optional context norm otherwise.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # [B, S, H*D] -> [B, H, S, D]
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
            # attention_mask: N x S1
            # encoder_attention_mask: N x S2
            # Cross attention: combine attention_mask and encoder_attention_mask
            # into an additive bias (-inf where either side is masked out).
            combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)

        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = optimized_attention(
            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
        ).to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
453
+
454
def val2list(x, repeat_time: int = 1) -> list:
    """Normalize *x* to a list.

    Lists and tuples are converted with ``list()``; any other value is
    repeated ``repeat_time`` times (the same object reference each time).

    Args:
        x: Value or list/tuple of values.
        repeat_time: Number of copies when *x* is a scalar.

    Returns:
        A list representation of *x*.
    """
    # NOTE: the previous annotation `x: list or tuple or any` evaluated to
    # just `list` and was misleading; the parameter genuinely accepts anything.
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x] * repeat_time
459
+
460
+
461
def val2tuple(x, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Normalize *x* to a tuple of at least ``min_len`` elements.

    *x* is first converted to a list via :func:`val2list`; if it is shorter
    than ``min_len``, the element at ``idx_repeat`` is repeated in place until
    the length reaches ``min_len``. Longer inputs are returned unchanged.

    Args:
        x: Value or list/tuple of values.
        min_len: Minimum length of the returned tuple.
        idx_repeat: Index of the element to repeat (and insertion point).

    Returns:
        A tuple of length ``max(len(x), min_len)`` (or empty if *x* is empty).
    """
    # NOTE: the previous annotation `x: list or tuple or any` evaluated to
    # just `list` and was misleading; the parameter genuinely accepts anything.
    x = val2list(x)

    # Repeat elements if necessary; a non-positive repeat count is a no-op.
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat]] * (min_len - len(x))

    return tuple(x)
471
+
472
+
473
def t2i_modulate(x, shift, scale):
    """Apply adaLN-style modulation: scale ``x`` by ``1 + scale`` then add ``shift``."""
    scaled = x * (scale + 1)
    return scaled + shift
475
+
476
+
477
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    """Return the "same" padding for an odd kernel size.

    Scalars yield ``kernel_size // 2``; tuples are mapped element-wise.
    Even kernel sizes are rejected because they have no symmetric padding.
    """
    if isinstance(kernel_size, tuple):
        return tuple(get_same_padding(ks) for ks in kernel_size)
    assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
    return kernel_size // 2
483
+
484
class ConvLayer(nn.Module):
    """Conv1d followed by an optional RMSNorm and an optional SiLU activation.

    NOTE(review): only the None-ness of ``norm`` and ``act`` matters — any
    non-None ``norm`` selects RMSNorm and any non-None ``act`` selects SiLU,
    regardless of the value passed. ``operations`` supplies the Conv1d/RMSNorm
    layer factories.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=None,
        act=None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        if padding is None:
            # Derive "same" padding and scale it by the dilation factor
            # (assumes an int kernel_size on this path).
            padding = get_same_padding(kernel_size)
            padding *= dilation

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.padding = padding
        self.use_bias = use_bias

        self.conv = operations.Conv1d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=use_bias,
            device=device,
            dtype=dtype
        )
        if norm is not None:
            self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
        else:
            self.norm = None
        if act is not None:
            self.act = nn.SiLU(inplace=True)
        else:
            self.act = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> (norm) -> (act) on a channel-first [B, C, S] tensor."""
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x
541
+
542
+
543
class GLUMBConv(nn.Module):
    """Gated inverted-bottleneck convolution block (GLU + MBConv style).

    Pipeline: 1x1 expansion conv to ``2 * hidden_features`` channels ->
    depthwise conv -> channel split into value/gate halves -> SiLU-gated
    multiply -> 1x1 projection back to ``out_feature``. Operates on
    sequence-last [B, S, C] tensors; transposed internally for Conv1d.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_feature=None,
        kernel_size=3,
        stride=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=(None, None, None),
        act=("silu", "silu", None),
        dilation=1,
        dtype=None, device=None, operations=None
    ):
        # Default the output width to the input width.
        out_feature = out_feature or in_features
        super().__init__()
        # Broadcast scalar settings to one entry per conv stage
        # (inverted / depthwise / pointwise).
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act = val2tuple(act, 3)

        # Gate nonlinearity for the GLU split in forward().
        self.glu_act = nn.SiLU(inplace=False)
        self.inverted_conv = ConvLayer(
            in_features,
            hidden_features * 2,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act=act[0],
            dtype=dtype,
            device=device,
            operations=operations,
        )
        # Depthwise conv: groups == channels; activation handled by the GLU gate.
        self.depth_conv = ConvLayer(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size,
            stride=stride,
            groups=hidden_features * 2,
            padding=padding,
            use_bias=use_bias[1],
            norm=norm[1],
            act=None,
            dilation=dilation,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.point_conv = ConvLayer(
            hidden_features,
            out_feature,
            1,
            use_bias=use_bias[2],
            norm=norm[2],
            act=act[2],
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated bottleneck to a [B, S, C] tensor."""
        # [B, S, C] -> [B, C, S] for Conv1d.
        x = x.transpose(1, 2)
        x = self.inverted_conv(x)
        x = self.depth_conv(x)

        # GLU: split channels into value and gate, gate through SiLU.
        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.glu_act(gate)
        x = x * gate

        x = self.point_conv(x)
        # Restore [B, S, C].
        x = x.transpose(1, 2)

        return x
616
+
617
+
618
class LinearTransformerBlock(nn.Module):
    """
    A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.

    Structure: RMSNorm -> linear self-attention (``CustomLiteLAProcessor2_0``,
    optionally joint over latent + context tokens) -> optional standard
    cross-attention (``CustomerAttnProcessor2_0``) -> RMSNorm -> GLUMBConv
    feed-forward. When ``use_adaln_single`` is set, a learned 6-vector
    scale/shift table combined with ``temb`` modulates and gates both the
    attention and feed-forward branches.
    """
    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        use_adaln_single=True,
        cross_attention_dim=None,
        added_kv_proj_dim=None,
        context_pre_only=False,
        mlp_ratio=4.0,
        add_cross_attention=False,
        add_cross_attention_dim=None,
        qk_norm=None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
        # Linear (lite) attention over the latent stream; handles the joint
        # latent+context stream when encoder_hidden_states is passed in.
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            qk_norm=qk_norm,
            processor=CustomLiteLAProcessor2_0(),
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.add_cross_attention = add_cross_attention
        self.context_pre_only = context_pre_only

        if add_cross_attention and add_cross_attention_dim is not None:
            # Optional standard softmax cross-attention to an extra context.
            self.cross_attn = Attention(
                query_dim=dim,
                cross_attention_dim=add_cross_attention_dim,
                added_kv_proj_dim=add_cross_attention_dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                context_pre_only=context_pre_only,
                bias=True,
                qk_norm=qk_norm,
                processor=CustomerAttnProcessor2_0(),
                dtype=dtype,
                device=device,
                operations=operations,
            )

        # NOTE: the second positional argument of RMSNorm here is eps.
        self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)

        # Feed-forward implemented as a gated conv block instead of an MLP.
        self.ff = GLUMBConv(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            use_bias=(True, True, False),
            norm=(None, None, None),
            act=("silu", "silu", None),
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.use_adaln_single = use_adaln_single
        if use_adaln_single:
            # 6 modulation vectors: shift/scale/gate for attention and for FF.
            self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        encoder_attention_mask: torch.FloatTensor = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        temb: torch.FloatTensor = None,
    ):
        """Run one transformer block; returns the updated hidden_states.

        ``temb`` is required when ``use_adaln_single`` is True (it provides
        the per-sample modulation added to the learned table).
        """

        N = hidden_states.shape[0]

        # step 1: AdaLN single
        if self.use_adaln_single:
            # Learned table + conditioning, split into the six modulators.
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
            ).chunk(6, dim=1)

        norm_hidden_states = self.norm1(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        # step 2: attention
        if not self.add_cross_attention:
            # Joint path: context tokens ride along inside self.attn.
            attn_output, encoder_hidden_states = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
        else:
            # Separate path: pure self-attention here, context handled by
            # the dedicated cross_attn below.
            attn_output, _ = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=None,
            )

        if self.use_adaln_single:
            attn_output = gate_msa * attn_output
        hidden_states = attn_output + hidden_states

        if self.add_cross_attention:
            attn_output = self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
            hidden_states = attn_output + hidden_states

        # step 3: add norm
        norm_hidden_states = self.norm2(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        # step 4: feed forward
        ff_output = self.ff(norm_hidden_states)
        if self.use_adaln_single:
            ff_output = gate_mlp * ff_output

        hidden_states = hidden_states + ff_output

        return hidden_states
ComfyUI/comfy/ldm/ace/lyric_encoder.py ADDED
@@ -0,0 +1,1067 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original from: https://github.com/ace-step/ACE-Step/blob/main/models/lyrics_utils/lyric_encoder.py
2
+ from typing import Optional, Tuple, Union
3
+ import math
4
+ import torch
5
+ from torch import nn
6
+
7
+ import comfy.model_management
8
+
9
class ConvolutionModule(nn.Module):
    """Conformer convolution block: pointwise GLU -> depthwise conv -> norm -> pointwise."""

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True,
                 dtype=None, device=None, operations=None):
        """Build the convolution block.

        Args:
            channels (int): Channel count of every conv layer.
            kernel_size (int): Depthwise conv kernel size.
            activation (nn.Module): Non-linearity applied after the norm.
            norm (str): Either ``"batch_norm"`` or ``"layer_norm"``.
            causal (bool): Left-pad instead of symmetric padding.
            bias (bool): Whether the conv layers carry a bias term.
        """
        super().__init__()

        self.pointwise_conv1 = operations.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0,
            bias=bias, dtype=dtype, device=device)

        # self.lorder > 0 marks a causal convolution: the input is left-padded
        # with lorder frames in forward(); lorder == 0 means symmetric padding.
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # a symmetric (non-causal) convolution needs an odd kernel
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = operations.Conv1d(
            channels, channels, kernel_size, stride=1, padding=padding,
            groups=channels, bias=bias, dtype=dtype, device=device)

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = operations.LayerNorm(channels, dtype=dtype, device=device)

        self.pointwise_conv2 = operations.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0,
            bias=bias, dtype=dtype, device=device)
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the convolution block.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): batch padding mask (#batch, 1, time);
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left-context cache (#batch, channels, cache_t),
                only used by causal convolution; (0, 0, 0) means fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
            torch.Tensor: New left-context cache.
        """
        x = x.transpose(1, 2)  # -> (#batch, channels, time)

        # zero out padded frames before convolving
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # no cache yet: pad with zeros on the left
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # batch must match
                assert cache.size(1) == x.size(1)  # channels must match
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # Return a fake zero-size cache instead of None so the module
            # stays JIT-exportable.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # gated linear unit
        x = nn.functional.glu(self.pointwise_conv1(x), dim=1)  # (batch, channels, time)

        # depthwise convolution followed by norm + activation; LayerNorm
        # expects channels last, hence the transposes around it
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)

        # re-apply the padding mask after convolving
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
135
+
136
class PositionwiseFeedForward(torch.nn.Module):
    """Position-wise feed forward layer, applied independently at each step.

    The output dimension equals the input dimension.

    Args:
        idim (int): Input dimension.
        hidden_units (int): Width of the hidden layer.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Non-linearity between the two layers.
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
        dtype=None, device=None, operations=None
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = operations.Linear(idim, hidden_units, dtype=dtype, device=device)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = operations.Linear(hidden_units, idim, dtype=dtype, device=device)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Map an input of shape (B, L, D) to an output of shape (B, L, D)."""
        hidden = self.activation(self.w_1(xs))
        return self.w_2(self.dropout(hidden))
173
+
174
class Swish(torch.nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Swish function element-wise."""
        return x * x.sigmoid()
180
+
181
class MultiHeadedAttention(nn.Module):
    """Standard multi-head scaled dot-product attention.

    Args:
        n_head (int): Number of attention heads.
        n_feat (int): Model feature size; must be divisible by ``n_head``.
        dropout_rate (float): Dropout applied to the attention weights.
        key_bias (bool): Whether the key projection carries a bias term.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True,
                 dtype=None, device=None, operations=None):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # d_v is assumed equal to d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
        self.linear_k = operations.Linear(n_feat, n_feat, bias=key_bias, dtype=dtype, device=device)
        self.linear_v = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
        self.linear_out = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project query/key/value and split them into heads.

        Args:
            query (torch.Tensor): (#batch, time1, size).
            key (torch.Tensor): (#batch, time2, size).
            value (torch.Tensor): (#batch, time2, size).

        Returns:
            Projected q (#batch, n_head, time1, d_k),
            k (#batch, n_head, time2, d_k) and
            v (#batch, n_head, time2, d_k).
        """
        batch = query.size(0)
        q = self.linear_q(query).view(batch, -1, self.h, self.d_k).transpose(1, 2)
        k = self.linear_k(key).view(batch, -1, self.h, self.d_k).transpose(1, 2)
        v = self.linear_v(value).view(batch, -1, self.h, self.d_k).transpose(1, 2)
        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Weight ``value`` by softmax(scores), honouring an optional mask.

        Args:
            value (torch.Tensor): (#batch, n_head, time2, d_k).
            scores (torch.Tensor): (#batch, n_head, time1, time2).
            mask (torch.Tensor): (#batch, 1, time2) or (#batch, time1, time2);
                (0, 0, 0) means fake mask (no masking).

        Returns:
            torch.Tensor: (#batch, time1, d_model).
        """
        batch = value.size(0)

        if mask is not None and mask.size(2) > 0:  # time2 > 0
            invalid = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # for the last chunk, time2 might exceed scores.size(-1)
            invalid = invalid[:, :, :, :scores.size(-1)]
            scores = scores.masked_fill(invalid, -float('inf'))
            # zero masked weights so -inf rows do not produce NaNs downstream
            weights = torch.softmax(scores, dim=-1).masked_fill(invalid, 0.0)
        else:
            weights = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        ctx = torch.matmul(self.dropout(weights), value)  # (batch, head, time1, d_k)
        ctx = ctx.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.linear_out(ctx)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot-product attention with an optional KV cache.

        Args:
            query (torch.Tensor): (#batch, time1, size).
            key (torch.Tensor): (#batch, time2, size).
            value (torch.Tensor): (#batch, time2, size).
            mask (torch.Tensor): (#batch, 1, time2) for cross attention or
                (#batch, time1, time2) for self attention; (0, 0, 0) = no mask.
            pos_emb (torch.Tensor): Unused here; kept for interface parity with
                the relative-position subclass.
            cache (torch.Tensor): (1, head, cache_t, d_k * 2) — keys and values
                concatenated along the last dimension.

        Returns:
            torch.Tensor: Output (#batch, time1, d_model).
            torch.Tensor: Updated cache (1, head, cache_t + time1, d_k * 2).
        """
        q, k, v = self.forward_qkv(query, key, value)
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache
329
+
330
+
331
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-head attention with Transformer-XL style relative positions.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): Number of attention heads.
        n_feat (int): Model feature size.
        dropout_rate (float): Dropout rate on the attention weights.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True,
                 dtype=None, device=None, operations=None):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias, dtype=dtype, device=device, operations=operations)
        # projection applied to the positional embeddings
        self.linear_pos = operations.Linear(n_feat, n_feat, bias=False, dtype=dtype, device=device)
        # learnable biases for the content (matrix c) and position (matrix d)
        # terms of https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.empty(self.h, self.d_k, dtype=dtype, device=device))
        self.pos_bias_v = nn.Parameter(torch.empty(self.h, self.d_k, dtype=dtype, device=device))

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Shift rows so that column j holds the score for relative position i-j.

        Args:
            x (torch.Tensor): (batch, head, time1, 2*time1-1), where time1 is
                the query length.

        Returns:
            torch.Tensor: Shifted tensor keeping only positions 0..time2.
        """
        b, h, t1 = x.size(0), x.size(1), x.size(2)
        pad = torch.zeros((b, h, t1, 1), device=x.device, dtype=x.dtype)
        shifted = torch.cat([pad, x], dim=-1)
        shifted = shifted.view(b, h, x.size(3) + 1, t1)
        # drop the first (padding) row, then keep positions 0..time2 only
        return shifted[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scaled dot-product attention with relative positional encoding.

        Args:
            query (torch.Tensor): (#batch, time1, size).
            key (torch.Tensor): (#batch, time2, size).
            value (torch.Tensor): (#batch, time2, size).
            mask (torch.Tensor): (#batch, 1, time2) or (#batch, time1, time2);
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding (#batch, time2, size).
            cache (torch.Tensor): (1, head, cache_t, d_k * 2) KV cache.

        Returns:
            torch.Tensor: Output (#batch, time1, d_model).
            torch.Tensor: Updated cache (1, head, cache_t + time1, d_k * 2).
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): cache slicing happens in encoder.forward_chunk, since
        # computing `next_cache_start` here is non-trivial.
        new_cache = torch.cat((k, v), dim=-1)

        p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        bias_u = comfy.model_management.cast_to(self.pos_bias_u, dtype=q.dtype, device=q.device)
        bias_v = comfy.model_management.cast_to(self.pos_bias_v, dtype=q.dtype, device=q.device)
        q_with_bias_u = (q + bias_u).transpose(1, 2)  # (batch, head, time1, d_k)
        q_with_bias_v = (q + bias_v).transpose(1, 2)  # (batch, head, time1, d_k)

        # matrices (a)+(c) and (b)+(d) of https://arxiv.org/abs/1901.02860 §3.3
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): keep rel_shift since espnet rel_pos_emb is used
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
447
+
448
+
449
+
450
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a lower-triangular mask (size, size) for auto-regressive decoding.

    Row i is True at columns 0..i, so step i may only attend to itself and
    earlier steps.  Full attention (no mask) is used in the encoder when
    streaming is unnecessary; see subsequent_chunk_mask for the chunk-based
    streaming variant.

    Args:
        size (int): size of mask
        device (str): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    idx = torch.arange(size, device=device)
    # column index <= row index -> inclusive lower triangle
    return idx.unsqueeze(0) <= idx.unsqueeze(1)
484
+
485
+
486
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a chunk-based attention mask (size, size) for a streaming encoder.

    Position i may attend to every position from the start of its allowed
    left-context window up to the end of its own chunk.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full left context
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # Vectorized replacement for the original per-row Python loop: builds the
    # same mask with a handful of tensor ops instead of `size` iterations of
    # row indexing (which is slow and forces host-side loops on GPU tensors).
    pos = torch.arange(size, device=device)
    chunk_idx = pos // chunk_size  # chunk index of every position
    # exclusive upper bound of the attendable range for each row
    ending = torch.clamp((chunk_idx + 1) * chunk_size, max=size)
    if num_left_chunks < 0:
        # full left context
        start = torch.zeros(size, dtype=torch.long, device=device)
    else:
        start = torch.clamp((chunk_idx - num_left_chunks) * chunk_size, min=0)
    cols = pos.unsqueeze(0)  # (1, size)
    return (cols >= start.unsqueeze(1)) & (cols < ending.unsqueeze(1))


def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional chunk mask for the encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        # no chunking requested: pass the padding mask through untouched
        chunk_masks = masks
    return chunk_masks
595
+
596
+
597
class ConformerEncoderLayer(nn.Module):
    """One Conformer encoder layer: (macaron FF) -> MHA -> (conv) -> FF.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module
            (`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`).
        feed_forward (torch.nn.Module): Feed-forward module
            (`PositionwiseFeedForward`).
        feed_forward_macaron (torch.nn.Module): Optional second feed-forward
            module placed before the attention (macaron style).
        conv_module (torch.nn.Module): Optional convolution module
            (`ConvolutionModule`).
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: apply layer_norm before each sub-block (pre-norm).
            False: apply layer_norm after each sub-block (post-norm).
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
        dtype=None, device=None, operations=None
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)  # FFN norm
        self.norm_mha = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)  # MHA norm
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)
            # macaron style halves each FF contribution
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)  # conv norm
        self.norm_final = operations.LayerNorm(
            size, eps=1e-5, dtype=dtype, device=device)  # final output norm
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): attention mask (#batch, time, time);
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding; must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask for the conv module
                (#batch, 1, time); (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): KEY & VALUE cache
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): convolution cache
                (#batch=1, size, cache_t2).
        Returns:
            torch.Tensor: Output (#batch, time, size).
            torch.Tensor: Mask (#batch, time, time).
            torch.Tensor: New att_cache (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: New cnn_cache (#batch, size, cache_t2).
        """

        # optional macaron-style feed forward (half-weighted)
        if self.feed_forward_macaron is not None:
            shortcut = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = shortcut + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention
        shortcut = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = shortcut + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module; start with a fake cnn cache and let the conv
        # module replace it when present
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            shortcut = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = shortcut + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        shortcut = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = shortcut + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        # the final norm is only applied when a conv module is present
        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
726
+
727
+
728
+
729
class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See also Appendix B in https://arxiv.org/abs/1901.02860.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct an PositionalEncoding object."""
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """(Re)build ``self.pe`` so it covers 2 * x.size(1) - 1 relative positions."""
        if self.pe is not None:
            # self.pe holds both positive and negative halves, so its length
            # is 2 * input_len - 1; reuse it when it is already large enough
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # `i` is the query position and `j` the key position: positive
        # relative positions are used when i > j, negative ones when i < j
        length = x.size(1)
        pe_positive = torch.zeros(length, self.d_model)
        pe_negative = torch.zeros(length, self.d_model)
        position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the positive half and concatenate both halves; this layout
        # supports the shifting trick of https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
            device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale ``x`` and return it with its relative position embedding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Scaled input (batch, time, `*`).
            torch.Tensor: Matching positional embedding.
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """Slice the 2*size-1 relative positions centred on relative offset 0.

        NOTE: dropout is applied once per call, so streaming use (several
        calls with growing size) applies dropout more than once per utterance.

        Args:
            offset (int or torch.Tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        center = self.pe.size(1) // 2
        return self.pe[:, center - size + 1: center + size]
823
+
824
+
825
+
826
+ class LinearEmbed(torch.nn.Module):
827
+ """Linear transform the input without subsampling
828
+
829
+ Args:
830
+ idim (int): Input dimension.
831
+ odim (int): Output dimension.
832
+ dropout_rate (float): Dropout rate.
833
+
834
+ """
835
+
836
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
837
+ pos_enc_class: torch.nn.Module, dtype=None, device=None, operations=None):
838
+ """Construct an linear object."""
839
+ super().__init__()
840
+ self.out = torch.nn.Sequential(
841
+ operations.Linear(idim, odim, dtype=dtype, device=device),
842
+ operations.LayerNorm(odim, eps=1e-5, dtype=dtype, device=device),
843
+ torch.nn.Dropout(dropout_rate),
844
+ )
845
+ self.pos_enc = pos_enc_class #rel_pos_espnet
846
+
847
+ def position_encoding(self, offset: Union[int, torch.Tensor],
848
+ size: int) -> torch.Tensor:
849
+ return self.pos_enc.position_encoding(offset, size)
850
+
851
+ def forward(
852
+ self,
853
+ x: torch.Tensor,
854
+ offset: Union[int, torch.Tensor] = 0
855
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
856
+ """Input x.
857
+
858
+ Args:
859
+ x (torch.Tensor): Input tensor (#batch, time, idim).
860
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
861
+
862
+ Returns:
863
+ torch.Tensor: linear input tensor (#batch, time', odim),
864
+ where time' = time .
865
+ torch.Tensor: linear input mask (#batch, 1, time'),
866
+ where time' = time .
867
+
868
+ """
869
+ x = self.out(x)
870
+ x, pos_emb = self.pos_enc(x, offset)
871
+ return x, pos_emb
872
+
873
+
874
# Registry mapping attention-type config names to the attention
# implementations defined earlier in this file.
ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
}

# Registry mapping activation config names to activation modules.
# "swish" prefers torch.nn.SiLU and falls back to the local Swish
# implementation on torch builds that predate SiLU.
ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    "swish": getattr(torch.nn, "SiLU", Swish),
    "gelu": torch.nn.GELU,
}
887
+
888
+
889
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): pad the mask out to this length; when <= 0 the
            longest sequence in the batch is used.
    Returns:
        torch.Tensor: Bool mask (B, max_len), True at padded positions.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    if max_len <= 0:
        max_len = lengths.max().item()
    batch = lengths.size(0)
    positions = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    # Broadcast (B, T) positions against (B, 1) lengths:
    # a position at or past the sequence length is padding.
    grid = positions.unsqueeze(0).expand(batch, max_len)
    return grid >= lengths.unsqueeze(-1)
916
+
917
#https://github.com/FunAudioLLM/CosyVoice/blob/main/examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml
class ConformerEncoder(torch.nn.Module):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 1024,
        attention_heads: int = 16,
        linear_units: int = 4096,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = 'linear',
        pos_enc_layer_type: str = 'rel_pos_espnet',
        normalize_before: bool = True,
        static_chunk_size: int = 1, # 1: causal_mask; 0: full_mask
        use_dynamic_chunk: bool = False,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool =False,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = False,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        dtype=None, device=None, operations=None
    ):
        """Construct ConformerEncoder

        Args:
            input_size to use_dynamic_chunk, see in BaseEncoder
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now, it's just for configure
                compatibility. #'rel_selfattn'
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            key_bias: whether use bias in attention.linear_k, False for whisper models.

        NOTE(review): `input_layer`, `pos_enc_layer_type`,
        `positionwise_conv_kernel_size` and `selfattention_layer_type` are
        accepted only for config compatibility; the code below always uses a
        linear embed with ESPnet-style relative positional encoding and
        RelPositionMultiHeadedAttention.
        """
        super().__init__()
        self.output_size = output_size
        # Frame-wise linear projection + relative positional encoding
        # (no temporal subsampling).
        self.embed = LinearEmbed(input_size, output_size, dropout_rate,
                                 EspnetRelPositionalEncoding(output_size, positional_dropout_rate), dtype=dtype, device=device, operations=operations)
        self.normalize_before = normalize_before
        self.after_norm = operations.LayerNorm(output_size, eps=1e-5, dtype=dtype, device=device)
        self.use_dynamic_chunk = use_dynamic_chunk

        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk  # duplicate of the assignment above; harmless
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        activation = ACTIVATION_CLASSES[activation_type]()

        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttention(
                    *encoder_selfattn_layer_args, dtype=dtype, device=device, operations=operations),
                PositionwiseFeedForward(*positionwise_layer_args, dtype=dtype, device=device, operations=operations),
                # Optional second ("macaron") feed-forward before attention.
                PositionwiseFeedForward(
                    *positionwise_layer_args, dtype=dtype, device=device, operations=operations) if macaron_style else None,
                # Optional depthwise convolution module.
                ConvolutionModule(
                    *convolution_layer_args, dtype=dtype, device=device, operations=operations) if use_cnn_module else None,
                dropout_rate,
                normalize_before, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_blocks)
        ])

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        # Run each conformer block in sequence; each layer returns the
        # (possibly unchanged) masks, which are threaded through.
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward(
        self,
        xs: torch.Tensor,
        pad_mask: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            pad_mask: padding mask (B, T); nonzero marks valid frames,
                or None to attend everywhere
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        masks = None
        if pad_mask is not None:
            masks = pad_mask.to(torch.bool).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb = self.embed(xs)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        # Combine the padding mask with the chunk/causal attention mask.
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)

        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks
1067
+
ComfyUI/comfy/ldm/ace/model.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
2
+
3
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Optional, List, Union
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ import comfy.model_management
22
+
23
+ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
24
+ from .attention import LinearTransformerBlock, t2i_modulate
25
+ from .lyric_encoder import ConformerEncoder as LyricEncoder
26
+
27
+
28
def cross_norm(hidden_states, controlnet_input):
    """Re-standardize the controlnet signal to match hidden_states' statistics.

    Both inputs are (N, T, C). Per-sample mean/std are taken over the (T, C)
    axes; the controlnet input is whitened and then rescaled/shifted to the
    hidden states' distribution.
    """
    h_mean = hidden_states.mean(dim=(1, 2), keepdim=True)
    h_std = hidden_states.std(dim=(1, 2), keepdim=True)
    c_mean = controlnet_input.mean(dim=(1, 2), keepdim=True)
    c_std = controlnet_input.std(dim=(1, 2), keepdim=True)
    # Small epsilon guards against division by a zero std.
    return (controlnet_input - c_mean) * (h_std / (c_std + 1e-12)) + h_mean
34
+
35
+
36
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) with a lazily-grown cos/sin cache.

    Caches cos/sin tables up to `max_position_embeddings` at construction and
    rebuilds them on demand when a longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies: base^(-2i/dim) for each even channel index i.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin lookup tables covering `seq_len` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # Fix: the original compared `None > int` and raised TypeError
            # whenever seq_len was omitted; infer it from x instead.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
71
+
72
+
73
class T2IFinalLayer(nn.Module):
    """
    The final layer of Sana.

    Projects transformer tokens back to latent patches (after timestep
    modulation) and reassembles them into an (N, C, H, W) latent.
    """

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
        # Learned (shift, scale) pair combined with the timestep embedding
        # in forward() for t2i_modulate.
        self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        """Reassemble the (N, T, p_h*p_w*C) token sequence into an
        (N, C, p_h, T*p_w) latent, then pad with zeros or crop on the last
        axis to `width`.

        NOTE(review): the pad/crop compares `width` against the token count
        `new_width`, which equals the pixel width only when patch_size[1]
        is 1 (the default) — confirm if other patch widths are ever used.
        """
        # 4 unpatchify
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
        ).contiguous()
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
        ).contiguous()
        if width > new_width:
            output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
        elif width < new_width:
            output = output[:, :, :, :width]
        return output

    def forward(self, x, t, output_length):
        """Modulate by timestep embedding `t`, project, and un-patchify to
        a latent of width `output_length`."""
        # Cast the learned table to t's device/dtype, add the per-sample
        # timestep embedding, and split into shift/scale halves.
        shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        # unpatchify
        output = self.unpatchfy(x, output_length)
        return output
113
+
114
+
115
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding.

    A strided convolution chops the latent into patches, GroupNorm
    normalizes them, and a 1x1 convolution projects to `embed_dim`.
    """

    def __init__(
        self,
        height=16,
        width=4096,
        patch_size=(16, 1),
        in_channels=8,
        embed_dim=1152,
        bias=True,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        expanded = in_channels * 256
        self.early_conv_layers = nn.Sequential(
            operations.Conv2d(in_channels, expanded, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
            operations.GroupNorm(num_groups=32, num_channels=expanded, eps=1e-6, affine=True, dtype=dtype, device=device),
            operations.Conv2d(expanded, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device),
        )
        self.patch_size = patch_size
        # Grid dimensions after patching.
        self.height = height // patch_size[0]
        self.width = width // patch_size[1]
        self.base_size = self.width

    def forward(self, latent):
        """Embed a latent: N x C x H x W -> N x (H/p_h * W/p_w) x embed_dim."""
        embedded = self.early_conv_layers(latent)
        # Flatten the spatial grid and move channels last: BCHW -> BNC.
        return embedded.flatten(2).transpose(1, 2)
144
+
145
+
146
class ACEStepTransformer2DModel(nn.Module):
    """Diffusion transformer for ACE-Step audio generation.

    Conditions on genre text embeddings, a speaker embedding, and an
    encoded lyric token sequence; denoises patchified audio latents with a
    stack of LinearTransformerBlocks using rotary position embeddings and
    timestep modulation.
    """
    # _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: Optional[int] = 8,
        num_layers: int = 28,
        inner_dim: int = 1536,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        mlp_ratio: float = 4.0,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_encoder_depths: List[int] = [9, 9],
        ssl_names: List[str] = ["mert", "m-hubert"],
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6681,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 4096,
        audio_model=None,
        dtype=None, device=None, operations=None

    ):
        super().__init__()

        self.dtype = dtype
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # NOTE(review): the `inner_dim` argument is overwritten here — the
        # effective inner dimension is always heads * head_dim.
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.out_channels = out_channels
        self.max_position = max_position
        self.patch_size = patch_size

        self.rope_theta = rope_theta

        # Rotary embedding shared by self- and cross-attention positions.
        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=self.attention_head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
            dtype=dtype,
            device=device,
        )

        # 2. Define input layers
        self.in_channels = in_channels

        self.num_layers = num_layers
        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                LinearTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    add_cross_attention=True,
                    add_cross_attention_dim=self.inner_dim,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for i in range(self.num_layers)
            ]
        )

        # Timestep conditioning: sinusoidal proj -> MLP -> 6-way modulation.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
        self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))

        # speaker
        self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)

        # genre
        self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)

        # lyric
        self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
        self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)

        projector_dim = 2 * self.inner_dim

        # Projection heads toward the SSL latent spaces (one per SSL model);
        # not used in forward() here — presumably for training losses.
        self.projectors = nn.ModuleList([
            nn.Sequential(
                operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
            ) for ssl_dim in ssl_latent_dims
        ])

        self.proj_in = PatchEmbed(
            height=max_height,
            width=max_width,
            patch_size=patch_size,
            embed_dim=self.inner_dim,
            bias=True,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)

    def forward_lyric_encoder(
        self,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        out_dtype=None,
    ):
        """Embed lyric tokens, run the conformer lyric encoder, and project
        the result to the transformer's inner dimension."""
        # N x T x D
        lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
        return prompt_prenet_out

    def encode(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        lyrics_strength=1.0,
    ):
        """Build the cross-attention conditioning sequence.

        Concatenates (in order) one speaker token, the projected genre/text
        embeddings, and the lyric encoder output (scaled by
        `lyrics_strength`), along with a matching mask when
        `text_attention_mask` is provided.
        """

        bs = encoder_text_hidden_states.shape[0]
        device = encoder_text_hidden_states.device

        # speaker embedding
        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)

        # genre embedding
        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)

        # lyric
        encoder_lyric_hidden_states = self.forward_lyric_encoder(
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
            out_dtype=encoder_text_hidden_states.dtype,
        )

        encoder_lyric_hidden_states *= lyrics_strength

        encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)

        encoder_hidden_mask = None
        if text_attention_mask is not None:
            # Single always-valid token for the speaker embedding.
            speaker_mask = torch.ones(bs, 1, device=device)
            encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)

        return encoder_hidden_states, encoder_hidden_mask

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
    ):
        """Denoise the latent sequence given the conditioning built by encode()."""
        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
        temb = self.t_block(embedded_timestep)

        hidden_states = self.proj_in(hidden_states)

        # controlnet logic: statistically match the control signal to the
        # hidden states before mixing it in.
        if block_controlnet_hidden_states is not None:
            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
            hidden_states = hidden_states + control_condi * controlnet_scale

        # inner_hidden_states = []

        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])

        for index_block, block in enumerate(self.transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_hidden_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                temb=temb,
            )

        # Project tokens back to the latent space at the requested length.
        output = self.final_layer(hidden_states, embedded_timestep, output_length)
        return output

    def forward(
        self,
        x,
        timestep,
        attention_mask=None,
        context: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        lyrics_strength=1.0,
        **kwargs
    ):
        """ComfyUI entry point: encode conditioning, then decode latents.

        `x` is the noisy latent (N, C, H, W); `context` carries the text
        embeddings. The output width is matched to x's last dimension.
        """
        hidden_states = x
        encoder_text_hidden_states = context
        encoder_hidden_states, encoder_hidden_mask = self.encode(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
            lyrics_strength=lyrics_strength,
        )

        output_length = hidden_states.shape[-1]

        output = self.decode(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            timestep=timestep,
            output_length=output_length,
            block_controlnet_hidden_states=block_controlnet_hidden_states,
            controlnet_scale=controlnet_scale,
        )

        return output
ComfyUI/comfy/ldm/ace/vae/music_log_mel.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+ import logging
6
try:
    from torchaudio.transforms import MelScale
# Fix: was a bare `except:`, which also swallowed SystemExit and
# KeyboardInterrupt; keep the best-effort import but catch only Exception.
except Exception:
    logging.warning("torchaudio missing, ACE model will be broken")
10
+
11
+ import comfy.model_management
12
+
13
class LinearSpectrogram(nn.Module):
    """Magnitude STFT front-end used by the log-mel spectrogram module."""

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        # Hann analysis window, kept as a non-trainable buffer.
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, y: Tensor) -> Tensor:
        """Compute the linear magnitude spectrogram of waveform `y`."""
        if y.ndim == 3:
            y = y.squeeze(1)

        # Reflect-pad so frame centers line up with the hop grid.
        pad_left = (self.win_length - self.hop_length) // 2
        pad_right = (self.win_length - self.hop_length + 1) // 2
        y = torch.nn.functional.pad(
            y.unsqueeze(1), (pad_left, pad_right), mode="reflect"
        ).squeeze(1)

        in_dtype = y.dtype
        # Run the STFT in float32 for numerical stability, cast back at the end.
        spec = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            # Magnitude with a small epsilon to keep the sqrt differentiable at 0.
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
        spec = spec.to(in_dtype)
        return spec
63
+
64
+
65
class LogMelSpectrogram(nn.Module):
    """Log-mel spectrogram front-end: STFT -> mel filterbank -> log compression."""

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        # Default to the Nyquist frequency when no upper bound is given.
        self.f_max = f_max or sample_rate // 2

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        # Slaney-normalized, Slaney-scaled mel filterbank over n_fft//2+1 bins.
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            "slaney",
            "slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        """Log-compress, clamping to 1e-5 to avoid log(0)."""
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        """Inverse of compress (exact except for values hit by the clamp)."""
        return torch.exp(x)

    # Fix: the return annotation previously claimed `-> Tensor`, but a
    # (mel, linear) tuple is returned when return_linear=True.
    def forward(self, x: Tensor, return_linear: bool = False) -> "Tensor | tuple[Tensor, Tensor]":
        """Compute the log-mel spectrogram of waveform `x`.

        Args:
            x: waveform tensor, (batch, time) or (batch, 1, time).
            return_linear: also return the log-compressed linear spectrogram.

        Returns:
            Log-mel spectrogram, or (log-mel, log-linear) when
            `return_linear` is True.
        """
        linear = self.spectrogram(x)
        x = self.mel_scale(linear)
        x = self.compress(x)
        if return_linear:
            return x, self.compress(linear)

        return x
ComfyUI/comfy/ldm/audio/autoencoder.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ import torch
4
+ from torch import nn
5
+ from typing import Literal
6
+ import math
7
+ import comfy.ops
8
+ ops = comfy.ops.disable_weight_init
9
+
10
def vae_sample(mean, scale):
    """Reparameterized sample from N(mean, std^2) plus its KL to N(0, 1).

    ``scale`` is mapped through softplus (with a small positive floor) to a
    standard deviation. The KL term is summed over dim 1 (channels) and then
    averaged over the remaining dims.
    """
    std = 1e-4 + nn.functional.softplus(scale)
    variance = std.pow(2)
    log_variance = variance.log()

    noise = torch.randn_like(mean)
    latents = mean + noise * std

    kl = (mean.pow(2) + variance - log_variance - 1).sum(1).mean()

    return latents, kl
19
+
20
class VAEBottleneck(nn.Module):
    """Continuous VAE bottleneck: splits channels into mean/scale and samples."""

    def __init__(self):
        super().__init__()
        # Continuous latent space (as opposed to a quantized/discrete one).
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        """Sample latents from ``x``; first half of dim 1 is the mean,
        second half the (pre-softplus) scale."""
        mean, scale = x.chunk(2, dim=1)
        latents, kl = vae_sample(mean, scale)

        if return_info:
            return latents, {"kl": kl}
        return latents

    def decode(self, x):
        # Latents pass through unchanged on the decode side.
        return x
41
+
42
+
43
def snake_beta(x, alpha, beta):
    """Snake activation with separate frequency (alpha) and magnitude (beta):
    x + sin^2(alpha * x) / beta, with a tiny epsilon guarding beta == 0."""
    return x + torch.sin(alpha * x) ** 2 / (beta + 1e-9)
45
+
46
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
47
class SnakeBeta(nn.Module):
    """Snake activation with learnable per-channel frequency (alpha) and
    magnitude (beta): x + (1/beta) * sin^2(x * alpha).

    Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py (MIT).
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale: # log scale alphas initialized to zeros
            # NOTE(review): zeros * alpha is a no-op — in this branch the
            # parameters start at 0 regardless of the ``alpha`` argument.
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else: # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        # self.alpha.requires_grad = alpha_trainable
        # self.beta.requires_grad = alpha_trainable

        # Epsilon used conceptually by snake_beta; kept for parity with upstream.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        # Broadcast the per-channel parameters against x laid out as [B, C, T].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
        if self.alpha_logscale:
            # Parameters live in log space; exponentiate to get positive values.
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x
76
+
77
def WNConv1d(*args, **kwargs):
    # 1D convolution wrapped in weight normalization; built on the comfy ops
    # variant so weight casting/offloading matches the rest of the model.
    return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
79
+
80
def WNConvTranspose1d(*args, **kwargs):
    # Transposed 1D convolution wrapped in weight normalization (comfy ops variant).
    return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
82
+
83
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build an activation module by name.

    Args:
        activation: one of "elu", "snake" (SnakeBeta — requires ``channels``)
            or "none" (identity).
        antialias: request the anti-aliased wrapper (not available in this port).
        channels: channel count, used only by the "snake" activation.

    Returns:
        The constructed activation module.

    Raises:
        ValueError: for an unknown activation name.
        NotImplementedError: if ``antialias`` is requested.
    """
    if activation == "elu":
        act = torch.nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        # The upstream ``Activation1d`` anti-aliasing wrapper is not ported
        # here; the original code raised a bare NameError on this path.
        # Fail explicitly instead of crashing on an undefined name.
        raise NotImplementedError("antialiased activations are not supported")

    return act
97
+
98
+
99
class ResidualUnit(nn.Module):
    """Dilated residual block: act -> dilated 7-tap conv -> act -> 1x1 conv,
    with an identity skip connection. The residual add requires
    in_channels == out_channels to broadcast."""

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        # "Same"-length padding for a kernel of 7 at the given dilation.
        padding = (dilation * (7-1)) // 2

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        res = x

        #x = checkpoint(self.layers, x)
        x = self.layers(x)

        return x + res
123
+
124
class EncoderBlock(nn.Module):
    """Encoder stage: three dilated residual units (dilations 1/3/9) followed
    by an activated strided conv that downsamples by ``stride`` and maps
    in_channels -> out_channels."""

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=9, use_snake=use_snake),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            # kernel = 2*stride with ceil(stride/2) padding keeps lengths
            # aligned with the downsampling factor.
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
        )

    def forward(self, x):
        return self.layers(x)
142
+
143
class DecoderBlock(nn.Module):
    """Decoder stage mirroring EncoderBlock: activated upsample (transposed
    conv, or nearest upsample + conv) followed by three dilated residual
    units (dilations 1/3/9)."""

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            # Nearest-neighbour upsample followed by a stride-1 conv avoids the
            # checkerboard artifacts of transposed convolutions.
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2*stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)
175
+
176
class OobleckEncoder(nn.Module):
    """Oobleck audio encoder: input conv, a pyramid of downsampling
    EncoderBlocks (channel multipliers ``c_mults``, strides ``strides``),
    then an activated conv to ``latent_dim`` channels."""

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        # Prepend a 1x multiplier for the stem stage.
        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        for i in range(self.depth-1):
            layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
208
+
209
+
210
class OobleckDecoder(nn.Module):
    """Oobleck audio decoder: mirror of OobleckEncoder — latent conv, a
    pyramid of upsampling DecoderBlocks (reversed multipliers/strides),
    then an activated output conv with optional tanh."""

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        # Prepend a 1x multiplier for the output stage.
        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
        ]

        # Walk the channel pyramid in reverse (widest first).
        for i in range(self.depth-1, 0, -1):
            layers += [DecoderBlock(
                in_channels=c_mults[i]*channels,
                out_channels=c_mults[i-1]*channels,
                stride=strides[i-1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
                )
            ]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            # Optional output squashing to [-1, 1].
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
252
+
253
+
254
class AudioOobleckVAE(nn.Module):
    """Full audio VAE: OobleckEncoder -> VAEBottleneck -> OobleckDecoder.

    The encoder outputs 2 * latent_dim channels (mean + scale) which the
    bottleneck splits and samples down to latent_dim.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults = [1, 2, 4, 8, 16],
                 strides = [2, 4, 4, 8, 8],
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
        # NOTE: the decoder's first positional parameter is ``out_channels``;
        # passing ``in_channels`` is intentional here — audio input and output
        # channel counts match for an autoencoder.
        self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
                                      use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        # Returns sampled latents (KL info is discarded here).
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        return self.decoder(self.bottleneck.decode(x))
276
+
ComfyUI/comfy/ldm/audio/dit.py ADDED
@@ -0,0 +1,896 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ from comfy.ldm.modules.attention import optimized_attention
4
+ import typing as tp
5
+
6
+ import torch
7
+
8
+ from einops import rearrange
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ import math
12
+ import comfy.ops
13
+
14
class FourierFeatures(nn.Module):
    """Random Fourier feature embedding: project input through a (checkpoint-
    loaded) weight matrix and emit [cos, sin] features of the result.

    Note: ``std`` is accepted for upstream-signature parity but unused here;
    the weight is created uninitialized and expected to come from a checkpoint.
    """

    def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
        super().__init__()
        # cos/sin pairs double the width, so the projection is half-sized.
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.empty(
            [out_features // 2, in_features], dtype=dtype, device=device))

    def forward(self, input):
        # cast_to_input keeps the weight on the input's device/dtype.
        f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
        return torch.cat([f.cos(), f.sin()], dim=-1)
24
+
25
+ # norms
26
class LayerNorm(nn.Module):
    """Layer norm with a learnable gain and optional bias.

    ``gamma`` is created uninitialized — presumably loaded from a checkpoint;
    ``fix_scale`` is accepted for upstream-signature parity but unused here.
    """

    def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
        """
        super().__init__()

        self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))

        if bias:
            self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
        else:
            self.beta = None

    def forward(self, x):
        beta = self.beta
        if beta is not None:
            # Match the input's device/dtype before handing to F.layer_norm.
            beta = comfy.ops.cast_to_input(beta, x)
        return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
45
+
46
class GLU(nn.Module):
    """Gated linear unit: project to 2 * dim_out, split, and gate one half
    with ``activation`` applied to the other (SwiGLU when activation=SiLU)."""

    def __init__(
        self,
        dim_in,
        dim_out,
        activation,
        use_conv = False,
        conv_kernel_size = 3,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.act = activation
        # Either a Linear over the feature dim or a same-padded Conv1d.
        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
        self.use_conv = use_conv

    def forward(self, x):
        if self.use_conv:
            # Conv1d expects channels-first; transpose around the projection.
            x = rearrange(x, 'b n d -> b d n')
            x = self.proj(x)
            x = rearrange(x, 'b d n -> b n d')
        else:
            x = self.proj(x)

        # Split into value and gate halves along the feature dim.
        x, gate = x.chunk(2, dim = -1)
        return x * self.act(gate)
73
+
74
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, scaled by dim ** -0.5."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len = x.shape[1]
        device = x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        # Shift positions for sequences that start mid-way, clamping at 0.
        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        return self.emb(pos) * self.scale
94
+
95
class ScaledSinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal positional embedding with a single learned scale."""

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        # One learnable scalar, initialized to dim ** -0.5.
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half = dim // 2
        exponents = torch.arange(half).float() / half
        self.register_buffer('inv_freq', theta ** -exponents, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        if pos is None:
            pos = torch.arange(x.shape[1], device = x.device)

        # Shift positions for sequences that start mid-way.
        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        angles = torch.einsum('i, j -> i j', pos, self.inv_freq)
        sinusoids = torch.cat((angles.sin(), angles.cos()), dim = -1)
        return sinusoids * self.scale
118
+
119
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE), optionally with xpos length scaling.

    ``inv_freq`` is registered as an uninitialized buffer — presumably loaded
    from a checkpoint rather than computed here (the original formula is left
    as a comment below).
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.,
        dtype=None,
        device=None,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        # xpos decay profile over the feature dimension.
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len, device, dtype):
        """Build embeddings for the contiguous positions 0..seq_len-1."""
        t = torch.arange(seq_len, device=device, dtype=dtype)
        return self.forward(t)

    def forward(self, t):
        """Return ``(freqs, scale)`` for the 1-D position tensor ``t``.

        ``scale`` is the scalar 1. when xpos is disabled, otherwise a
        per-position scaling tensor matching ``freqs``.
        """
        device = t.device
        # Fix: ``seq_len`` was previously undefined in the xpos branch below,
        # raising NameError whenever use_xpos=True.
        seq_len = t.shape[-1]

        # Positional interpolation for longer-than-trained sequences.
        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # xpos: scale decays symmetrically around the sequence midpoint.
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = comfy.ops.cast_to_input(self.scale, t) ** power.unsqueeze(-1)
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
177
+
178
def rotate_half(x):
    """Rotate feature pairs for RoPE: split the last dim into halves (a, b)
    and return (-b, a) concatenated."""
    a, b = x.chunk(2, dim = -1)
    return torch.cat((-b, a), dim = -1)
182
+
183
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Apply rotary embeddings ``freqs`` to the first ``freqs.shape[-1]``
    features of ``t``; remaining features pass through unrotated (partial
    RoPE, Wang et al. GPT-J style). Output keeps ``t``'s original dtype."""
    out_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
    freqs, t = freqs.to(dtype), t.to(dtype)
    # Align freqs with the tail of the sequence (handles prepended tokens).
    freqs = freqs[-seq_len:, :]

    # Broadcast batched freqs across the head dimension of [b, h, n, d] inputs.
    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

    t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)

    return torch.cat((t, t_unrotated), dim = -1)
202
+
203
class FeedForward(nn.Module):
    """Transformer feed-forward block. Defaults to a SwiGLU input projection
    (GLU with SiLU) expanding to ``dim * mult``, followed by a linear output.

    ``zero_init_output`` is accepted for upstream parity but the zero-init is
    commented out (weights come from a checkpoint in this port).
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        inner_dim = int(dim * mult)

        # Default to SwiGLU

        activation = nn.SiLU()

        dim_out = dim if dim_out is None else dim_out

        if glu:
            linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
        else:
            # NOTE(review): ``rearrange('b n d -> b d n')`` is called with only a
            # pattern (no tensor); the use_conv non-GLU path looks broken and
            # likely needs einops.layers.torch.Rearrange — confirm before use.
            linear_in = nn.Sequential(
                rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
                rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                activation
            )

        linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)

        # # init last linear layer to 0
        # if zero_init_output:
        #     nn.init.zeros_(linear_out.weight)
        #     if not no_bias:
        #         nn.init.zeros_(linear_out.bias)


        self.ff = nn.Sequential(
            linear_in,
            rearrange('b d n -> b n d') if use_conv else nn.Identity(),
            linear_out,
            rearrange('b n d -> b d n') if use_conv else nn.Identity(),
        )

    def forward(self, x):
        return self.ff(x)
255
+
256
class Attention(nn.Module):
    """Multi-head attention with optional cross-attention, QK-normalization
    and rotary embeddings, dispatching to comfy's ``optimized_attention``.

    When ``dim_context`` is given, separate q and fused k/v projections are
    created (cross-attention); otherwise a single fused qkv projection.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        dim_context = None,
        causal = False,
        zero_init_output=True,
        qk_norm = False,
        natten_kernel_size = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.causal = causal

        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
            self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
        else:
            self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)

        self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # if zero_init_output:
        #     nn.init.zeros_(self.to_out.weight)

        self.qk_norm = qk_norm


    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None,
        causal = None
    ):
        """Attend over ``x`` (or over ``context`` for cross-attention).

        NOTE(review): the ``masks`` list and the ``causal`` flag computed
        below are never passed to ``optimized_attention`` — masking here only
        affects the final output zeroing. Confirm against upstream intent.
        """
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Use separate linear projections for q and k/v
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)

            k, v = self.to_kv(kv_input).chunk(2, dim=-1)

            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
        else:
            # Use fused linear projection
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Normalize q and k for cosine sim attention
        if self.qk_norm:
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)

        if rotary_pos_emb is not None and not has_context:
            freqs, _ = rotary_pos_emb

            q_dtype = q.dtype
            k_dtype = k.dtype

            # RoPE is applied in float32 for numerical stability, then cast back.
            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)

            q = apply_rotary_pos_emb(q, freqs)
            k = apply_rotary_pos_emb(k, freqs)

            q = q.to(q_dtype)
            k = k.to(k_dtype)

        input_mask = context_mask

        if input_mask is None and not has_context:
            input_mask = mask

        # determine masking
        masks = []

        if input_mask is not None:
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        # Other masks will be added here later
        n = q.shape[-2]

        causal = self.causal if causal is None else causal

        if n == 1 and causal:
            # A single query token has nothing to mask causally.
            causal = False

        if h != kv_h:
            # Repeat interleave kv_heads to match q_heads
            heads_per_kv_head = h // kv_h
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

        out = optimized_attention(q, k, v, h, skip_reshape=True)
        out = self.to_out(out)

        if mask is not None:
            # Zero out the outputs of padded query positions.
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        return out
374
+
375
class ConformerModule(nn.Module):
    """Conformer convolution module: norm -> pointwise conv -> GLU ->
    depthwise conv (kernel 17) -> norm -> SiLU -> pointwise conv.

    Operates on [B, N, D] input, transposing to channels-first around each conv.
    """

    def __init__(
        self,
        dim,
        norm_kwargs = {},
    ):

        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.glu = GLU(dim, dim, nn.SiLU())
        # groups=dim makes this a depthwise conv; padding 8 keeps length for kernel 17.
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.in_norm(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.glu(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.depthwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.mid_norm(x)
        x = self.swish(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv_2(x)
        x = rearrange(x, 'b d n -> b n d')

        return x
410
+
411
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention,
    optional conformer conv module, and a feed-forward, each with residual
    connections. When ``global_cond_dim`` is set, adaLN-style scale/shift/gate
    modulation (6 * dim parameters) conditions the self-attn and FF branches.

    NOTE(review): ``attn_kwargs``/``ff_kwargs``/``norm_kwargs`` use mutable
    default dicts — only read here, but avoid mutating them.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        cross_attend = False,
        dim_context = None,
        global_cond_dim = None,
        causal = False,
        zero_init_branch_outputs = True,
        conformer = False,
        layer_ix = -1,
        remove_norms = False,
        attn_kwargs = {},
        ff_kwargs = {},
        norm_kwargs = {},
        dtype=None,
        device=None,
        operations=None,
    ):

        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()

        self.self_attn = Attention(
            dim,
            dim_heads = dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            dtype=dtype,
            device=device,
            operations=operations,
            **attn_kwargs
        )

        if cross_attend:
            self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
            self.cross_attn = Attention(
                dim,
                dim_heads = dim_heads,
                dim_context=dim_context,
                causal = causal,
                zero_init_output=zero_init_branch_outputs,
                dtype=dtype,
                device=device,
                operations=operations,
                **attn_kwargs
            )

        self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)

        self.layer_ix = layer_ix

        self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None

        self.global_cond_dim = global_cond_dim

        if global_cond_dim is not None:
            # adaLN modulation head: zero-initialized so the block starts as identity.
            self.to_scale_shift_gate = nn.Sequential(
                nn.SiLU(),
                nn.Linear(global_cond_dim, dim * 6, bias=False)
            )

            nn.init.zeros_(self.to_scale_shift_gate[1].weight)
            #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)

    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None
    ):
        if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

            # One projection yields all six modulation tensors at once.
            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)

            # self-attention with adaLN
            residual = x
            x = self.pre_norm(x)
            x = x * (1 + scale_self) + shift_self
            x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
            # Gate the branch output (sigmoid(1 - gate) == sigmoid(1) at init
            # since the modulation head is zero-initialized).
            x = x * torch.sigmoid(1 - gate_self)
            x = x + residual

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            # feedforward with adaLN
            residual = x
            x = self.ff_norm(x)
            x = x * (1 + scale_ff) + shift_ff
            x = self.ff(x)
            x = x * torch.sigmoid(1 - gate_ff)
            x = x + residual

        else:
            # Plain pre-norm residual path without conditioning.
            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            x = x + self.ff(self.ff_norm(x))

        return x
531
+
532
class ContinuousTransformer(nn.Module):
    """Stack of TransformerBlocks over continuous (non-tokenized) inputs,
    with optional in/out projections, rotary or absolute/sinusoidal position
    embeddings, prepended embeddings, and ComfyUI block-replacement patching.
    """

    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in = None,
        dim_out = None,
        dim_heads = 64,
        cross_attend=False,
        cond_token_dim=None,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):

        super().__init__()

        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])

        self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
        self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()

        if rotary_pos_emb:
            # Partial RoPE over half the head dim (at least 32 features).
            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
        else:
            self.rotary_pos_emb = None

        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)

        for i in range(depth):
            self.layers.append(
                TransformerBlock(
                    dim,
                    dim_heads = dim_heads,
                    cross_attend = cross_attend,
                    dim_context = cond_token_dim,
                    global_cond_dim = global_cond_dim,
                    causal = causal,
                    zero_init_branch_outputs = zero_init_branch_outputs,
                    conformer=conformer,
                    layer_ix=i,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                    **kwargs
                )
            )

    def forward(
        self,
        x,
        mask = None,
        prepend_embeds = None,
        prepend_mask = None,
        global_cond = None,
        return_info = False,
        **kwargs
    ):
        """Run the block stack over ``x``.

        Requires ``kwargs["context"]`` (cross-attention conditioning) to be
        present — it is read unconditionally below.
        """
        patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
        batch, seq, device = *x.shape[:2], x.device
        context = kwargs["context"]

        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)

        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim = -2)

            # Extend the attention mask to cover the prepended tokens.
            if prepend_mask is not None or mask is not None:
                mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
                prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)

                mask = torch.cat((prepend_mask, mask), dim = -1)

        # Attention layers

        if self.rotary_pos_emb is not None:
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
        else:
            rotary_pos_emb = None

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        blocks_replace = patches_replace.get("dit", {})
        # Iterate over the transformer layers
        for i, layer in enumerate(self.layers):
            if ("double_block", i) in blocks_replace:
                # ComfyUI patch hook: let a replacement wrap the original block.
                def block_wrap(args):
                    out = {}
                    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
                # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

            if return_info:
                info["hidden_states"].append(x)

        x = self.project_out(x)

        if return_info:
            return x, info

        return x
666
+
667
class AudioDiffusionTransformer(nn.Module):
    """Stable-Audio style diffusion transformer over 1D audio latents.

    Wraps a ContinuousTransformer with projections for timestep, global,
    cross-attention, and prepend conditioning, plus 1x1 residual pre/post
    convolutions on the channel dimension.
    """

    def __init__(self,
        io_channels=64,  # latent channels in and out
        patch_size=1,  # temporal patching factor
        embed_dim=1536,  # transformer width
        cond_token_dim=768,  # cross-attention token dim (0 disables cross-attn)
        project_cond_tokens=False,
        global_cond_dim=1536,  # global conditioning dim (0 disables)
        project_global_cond=True,
        input_concat_dim=0,  # extra channels concatenated onto the input
        prepend_cond_dim=0,  # dim of conditioning tokens prepended to the sequence
        depth=24,
        num_heads=24,
        transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
        global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
        audio_model="",
        dtype=None,
        device=None,
        operations=None,
        **kwargs):

        super().__init__()

        self.dtype = dtype
        self.cond_token_dim = cond_token_dim

        # Timestep embeddings
        timestep_features_dim = 256

        self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)

        self.to_timestep_embed = nn.Sequential(
            operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
        )

        if cond_token_dim > 0:
            # Conditioning tokens

            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = nn.Sequential(
                operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
            )
        else:
            # NOTE: to_cond_embed is undefined in this case; _forward only
            # touches it when cross_attn_cond is provided.
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = nn.Sequential(
                operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
            )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
            )

        self.input_concat_dim = input_concat_dim

        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        # Transformer

        self.transformer_type = transformer_type

        self.global_cond_type = global_cond_type

        if self.transformer_type == "continuous_transformer":

            global_dim = None

            if self.global_cond_type == "adaLN":
                # The global conditioning is projected to the embed_dim already at this point
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend = cond_token_dim > 0,
                cond_token_dim = cond_embed_dim,
                global_cond_dim=global_dim,
                dtype=dtype,
                device=device,
                operations=operations,
                **kwargs
            )
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        # 1x1 convs applied as residual corrections before/after the transformer.
        self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
        self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)

    def _forward(
        self,
        x,
        t,
        mask=None,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        input_concat_cond=None,
        global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        return_info=False,
        **kwargs):
        """Core denoising pass.

        x is (batch, channels, time); t is a (batch,) timestep tensor.
        Conditioning inputs are projected, the timestep embedding is folded
        into the global conditioning, and (for global_cond_type == "prepend")
        global/prepend conditioning is prepended to the token sequence and
        stripped from the output.
        """

        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:

            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            x = torch.cat([x, input_concat_cond], dim=1)

        # Get the batch of timestep embeddings
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)

        # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
        if global_embed is not None:
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed

        # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
        if self.global_cond_type == "prepend":
            if prepend_inputs is None:
                # Prepend inputs are just the global embed, and the mask is all ones
                prepend_inputs = global_embed.unsqueeze(1)
                prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
            else:
                # Prepend inputs are the prepend conditioning + the global embed
                prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
                prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)

            prepend_length = prepend_inputs.shape[1]

        # Residual 1x1 conv, then move to (batch, time, channels) for the transformer.
        x = self.preprocess_conv(x) + x

        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            # Fold `patch_size` adjacent timesteps into the channel dim.
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        # NOTE(review): only "continuous_transformer" can be constructed in
        # __init__; the "x-transformers" and "mm_transformer" branches below
        # look unreachable here — confirm before removing.
        if self.transformer_type == "x-transformers":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
        elif self.transformer_type == "continuous_transformer":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)

            if return_info:
                output, info = output
        elif self.transformer_type == "mm_transformer":
            output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)

        # Back to (batch, channels, time) and drop the prepended conditioning tokens.
        output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
        self,
        x,
        timestep,
        context=None,
        context_mask=None,
        input_concat_cond=None,
        global_embed=None,
        negative_global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        mask=None,
        return_info=False,
        control=None,
        **kwargs):
        """Public entry point; maps ComfyUI argument names onto _forward.

        `negative_global_embed` and `control` are accepted but unused here.
        """
        return self._forward(
            x,
            timestep,
            cross_attn_cond=context,
            cross_attn_cond_mask=context_mask,
            input_concat_cond=input_concat_cond,
            global_embed=global_embed,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            mask=mask,
            return_info=return_info,
            **kwargs
        )
ComfyUI/comfy/ldm/audio/embedders.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+ from typing import List, Union
7
+ from einops import rearrange
8
+ import math
9
+ import comfy.ops
10
+
11
class LearnedPositionalEmbedding(nn.Module):
    """Fourier-style embedding of a scalar (continuous time) with learned frequencies.

    Maps a batch of scalars (b,) to (b, dim + 1): the raw value concatenated
    with sin/cos features at dim // 2 learned frequencies.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.empty(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        t = x.unsqueeze(-1)  # (b,) -> (b, 1)
        # (b, 1) * (1, half_dim) broadcasts to (b, half_dim)
        freqs = t * self.weights.unsqueeze(0) * 2 * math.pi
        return torch.cat((t, freqs.sin(), freqs.cos()), dim=-1)
26
+
27
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a module embedding a scalar time value into `out_features` dims.

    Pipeline: learned Fourier features followed by a linear projection. The
    Linear comes from comfy.ops.manual_cast so weights are cast on the fly.
    """
    fourier = LearnedPositionalEmbedding(dim)
    # dim + 1 because LearnedPositionalEmbedding keeps the raw scalar as well.
    project = comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(fourier, project)
32
+
33
+
34
class NumberEmbedder(nn.Module):
    """Embed float inputs of arbitrary shape into `features`-dim vectors.

    Each scalar is embedded independently; the input shape is preserved with
    a trailing `features` dimension appended.
    """

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if not torch.is_tensor(x):
            # Lists are tensorized on the same device as the embedding weights.
            target_device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=target_device)
        assert isinstance(x, Tensor)
        original_shape = x.shape
        flat = x.reshape(-1)  # flatten so every scalar is embedded independently
        embedded = self.embedding(flat)
        return embedded.view(*original_shape, self.features)  # type: ignore
54
+
55
+
56
class Conditioner(nn.Module):
    """Base class for conditioners: stores dims plus an optional projection.

    `proj_out` is a Linear when input/output dims differ (or `project_out`
    forces it); otherwise it is a no-op Identity. Subclasses implement forward.
    """

    def __init__(
        self,
        dim: int,
        output_dim: int,
        project_out: bool = False
    ):

        super().__init__()

        self.dim = dim
        self.output_dim = output_dim
        needs_projection = (dim != output_dim) or project_out
        self.proj_out = nn.Linear(dim, output_dim) if needs_projection else nn.Identity()

    def forward(self, x):
        # Concrete conditioning logic lives in subclasses.
        raise NotImplementedError()
72
+
73
class NumberConditioner(Conditioner):
    '''
    Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
    '''
    def __init__(self,
        output_dim: int,
        min_val: float=0,
        max_val: float=1
    ):
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats, device=None):
        # Coerce every entry to a python float before tensorizing.
        values = [float(v) for v in floats]

        if device is None:
            device = next(self.embedder.parameters()).device

        values = torch.tensor(values).to(device)

        # Clamp into range, then map onto [0, 1].
        values = values.clamp(self.min_val, self.max_val)
        normalized = (values - self.min_val) / (self.max_val - self.min_val)

        # Run the embedder in its own parameter dtype.
        embedder_dtype = next(self.embedder.parameters()).dtype
        normalized = normalized.to(embedder_dtype)

        embeddings = self.embedder(normalized).unsqueeze(1)
        attention_mask = torch.ones(embeddings.shape[0], 1).to(device)

        return [embeddings, attention_mask]
ComfyUI/comfy/ldm/aura/mmdit.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #AuraFlow MMDiT
2
+ #Originally written by the AuraFlow Authors
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from comfy.ldm.modules.attention import optimized_attention
11
+ import comfy.ops
12
+ import comfy.ldm.common_dit
13
+
14
def modulate(x, shift, scale):
    """adaLN modulation: scale x by (1 + scale) and add shift, per batch item."""
    gain = scale.unsqueeze(1) + 1
    return x * gain + shift.unsqueeze(1)
16
+
17
+
18
def find_multiple(n: int, k: int) -> int:
    """Round n up to the nearest multiple of k (n itself if already aligned)."""
    remainder = n % k
    if remainder:
        return n + k - remainder
    return n
22
+
23
+
24
class MLP(nn.Module):
    """SwiGLU feed-forward block: silu(W1 x) * (W2 x) projected back to `dim`.

    The hidden width is 2/3 of `hidden_dim`, rounded up to a multiple of 256,
    which keeps the parameter count close to a plain 4x MLP.
    """

    def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        hidden_dim = 4 * dim if hidden_dim is None else hidden_dim

        width = int(2 * hidden_dim / 3)
        # round up to a multiple of 256 (inlined find_multiple)
        if width % 256 != 0:
            width += 256 - (width % 256)

        self.c_fc1 = operations.Linear(dim, width, bias=False, dtype=dtype, device=device)
        self.c_fc2 = operations.Linear(dim, width, bias=False, dtype=dtype, device=device)
        self.c_proj = operations.Linear(width, dim, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.c_fc1(x))
        return self.c_proj(gate * self.c_fc2(x))
41
+
42
+
43
class MultiHeadLayerNorm(nn.Module):
    """LayerNorm over the last dim with a learned gain and no bias.

    Statistics are computed in float32 and the result is cast back to the
    input dtype (same recipe as the HF Cohere LayerNorm this was copied from).
    """

    def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        h = hidden_states.to(torch.float32)
        mu = h.mean(-1, keepdim=True)
        var = (h - mu).pow(2).mean(-1, keepdim=True)
        normed = (h - mu) * torch.rsqrt(var + self.variance_epsilon)
        normed = self.weight.to(torch.float32) * normed
        return normed.to(orig_dtype)
61
+
62
class SingleAttention(nn.Module):
    """Self-attention over a single token stream with per-head QK normalization."""

    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        def make_qk_norm():
            # Multi-head norm normalizes (head, head_dim) jointly; otherwise a
            # plain affine-free LayerNorm over each head's channels.
            if mh_qknorm:
                return MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            return operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)

        self.q_norm1 = make_qk_norm()
        self.k_norm1 = make_qk_norm()

    #@torch.compile()
    def forward(self, c):
        bsz, seqlen1, _ = c.shape

        def split_heads(t):
            return t.view(bsz, seqlen1, self.n_heads, self.head_dim)

        q = self.q_norm1(split_heads(self.w1q(c)))
        k = self.k_norm1(split_heads(self.w1k(c)))
        v = split_heads(self.w1v(c))

        # optimized_attention with skip_reshape=True expects (b, heads, seq, head_dim)
        output = optimized_attention(
            q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3),
            self.n_heads, skip_reshape=True,
        )
        return self.w1o(output)
100
+
101
+
102
+
103
class DoubleAttention(nn.Module):
    """MMDiT-style joint attention over two streams: cond tokens `c` and image tokens `x`.

    Each stream has its own q/k/v/out projections and per-head qk-norms; the
    two streams are concatenated along the sequence dim for a single attention
    call, then split back and projected separately.
    """

    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # this is for x
        self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # qk-norms for the cond stream
        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

        # qk-norms for the x stream
        self.q_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )


    #@torch.compile()
    def forward(self, c, x):
        """c: (b, seqlen1, dim) cond tokens; x: (b, seqlen2, dim) image tokens.
        Returns the two streams, same shapes, after joint attention."""

        bsz, seqlen1, _ = c.shape
        bsz, seqlen2, _ = x.shape

        # Project and head-split the cond stream, normalizing q/k.
        cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
        cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
        ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cq, ck = self.q_norm1(cq), self.k_norm1(ck)

        # Same for the x stream with its own weights/norms.
        xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
        xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xq, xk = self.q_norm2(xq), self.k_norm2(xk)

        # concat all
        # Cond tokens come first; the split below relies on this ordering.
        q, k, v = (
            torch.cat([cq, xq], dim=1),
            torch.cat([ck, xk], dim=1),
            torch.cat([cv, xv], dim=1),
        )

        # optimized_attention with skip_reshape=True expects (b, heads, seq, head_dim)
        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)

        c, x = output.split([seqlen1, seqlen2], dim=1)
        c = self.w1o(c)
        x = self.w2o(x)

        return c, x
177
+
178
+
179
class MMDiTBlock(nn.Module):
    """Dual-stream adaLN transformer block: cond tokens and image tokens run
    through separate modulation/MLP paths joined by DoubleAttention."""

    def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        if not is_last:
            self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
            )
        else:
            # Last block only needs shift/scale for the cond stream (no MLP path).
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
            )

        self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
        self.modX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.is_last = is_last

    #@torch.compile()
    def forward(self, c, x, global_cond, **kwargs):
        """c: cond tokens, x: image tokens, global_cond: (b, global_conddim)."""

        cres, xres = c, x

        # NOTE(review): when is_last=True, modC outputs only 2*dim, so this
        # chunk(6) (and the mlpC use below) would fail; as constructed in MMDiT,
        # is_last is never True for instantiated blocks — confirm before reuse.
        cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
            self.modC(global_cond).chunk(6, dim=1)
        )

        c = modulate(self.normC1(c), cshift_msa, cscale_msa)

        # xpath
        xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
            self.modX(global_cond).chunk(6, dim=1)
        )

        x = modulate(self.normX1(x), xshift_msa, xscale_msa)

        # attention
        c, x = self.attn(c, x)


        # Gated residual + MLP path for the cond stream; note the final residual
        # uses the block input (cres), not the post-attention sum.
        c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
        c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
        c = cres + c

        # Same structure for the image stream.
        x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
        x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
        x = xres + x

        return c, x
239
+
240
class DiTBlock(nn.Module):
    """Single-stream adaLN DiT block — like MMDiTBlock but with one token path."""

    def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
        super().__init__()

        self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)

        self.modCX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

    #@torch.compile()
    def forward(self, cx, global_cond, **kwargs):
        residual = cx

        mods = self.modCX(global_cond).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods

        # attention sub-block with gated residual
        attn_out = self.attn(modulate(self.norm1(cx), shift_msa, scale_msa))
        cx = self.norm2(residual + gate_msa.unsqueeze(1) * attn_out)

        # MLP sub-block; the final residual is the block *input*, matching MMDiTBlock
        mlp_out = gate_mlp.unsqueeze(1) * self.mlp(modulate(cx, shift_mlp, scale_mlp))
        return residual + mlp_out
271
+
272
+
273
+
274
class TimestepEmbedder(nn.Module):
    """Embed a scalar diffusion timestep into a `hidden_size` vector using
    sinusoidal frequency features followed by a two-layer MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return (len(t), dim) sinusoidal features: [cos | sin] halves,
        zero-padded by one column when dim is odd."""
        half = dim // 2
        # frequencies scaled by 1000 so t in [0, 1] spans the usual DiT range
        freqs = 1000 * torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    #@torch.compile()
    def forward(self, t, dtype):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        return self.mlp(freq_emb)
303
+
304
+
305
class MMDiT(nn.Module):
    """AuraFlow MMDiT: a few dual-stream (text+image) blocks followed by
    single-stream blocks over the concatenated sequence, with a learned 2D
    positional encoding that is center-cropped / interpolated to the input size.
    """

    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        patch_size=2,
        dim=3072,
        n_layers=36,
        n_double_layers=4,
        n_heads=12,
        global_conddim=3072,
        cond_seq_dim=2048,
        max_seq=32 * 32,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype

        self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)

        self.cond_seq_linear = operations.Linear(
            cond_seq_dim, dim, bias=False, dtype=dtype, device=device
        )  # linear for something like text sequence.
        self.init_x_linear = operations.Linear(
            patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
        )  # init linear for patchified image.

        # Learned PE over a max_seq-token grid; 8 learned register tokens are
        # prepended to the text sequence.
        self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
        self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))

        self.double_layers = nn.ModuleList([])
        self.single_layers = nn.ModuleList([])


        # NOTE(review): is_last is only True when idx == n_layers - 1, which a
        # double layer (idx < n_double_layers < n_layers) never reaches.
        for idx in range(n_double_layers):
            self.double_layers.append(
                MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
            )

        for idx in range(n_double_layers, n_layers):
            self.single_layers.append(
                DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
            )


        self.final_linear = operations.Linear(
            dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
        )

        # Final adaLN-style shift/scale from the global conditioning.
        self.modF = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
        )

        self.out_channels = out_channels
        self.patch_size = patch_size
        self.n_double_layers = n_double_layers
        self.n_layers = n_layers

        # PE grid is assumed square: h_max * w_max == max_seq.
        self.h_max = round(max_seq**0.5)
        self.w_max = round(max_seq**0.5)

    @torch.no_grad()
    def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
        """Bilinearly resize the learned positional-encoding grid in place
        from init_dim to target_dim (used to adapt checkpoints to larger res)."""
        # extend pe
        pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]

        pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)

        # now we need to extend this to target_dim. for this we will use interpolation.
        # we will use torch.nn.functional.interpolate
        pe_as_2d = F.interpolate(
            pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
        )
        pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
        self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
        self.h_max, self.w_max = target_dim

    def pe_selection_index_based_on_dim(self, h, w):
        """Return flat PE indices for a centered h x w (pixel) crop of the PE
        grid. NOTE(review): currently unused — forward uses apply_pos_embeds."""
        h_p, w_p = h // self.patch_size, w // self.patch_size
        original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
        original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
        starth = self.h_max // 2 - h_p // 2
        endh =starth + h_p
        startw = self.w_max // 2 - w_p // 2
        endw = startw + w_p
        original_pe_indexes = original_pe_indexes[
            starth:endh, startw:endw
        ]
        return original_pe_indexes.flatten()

    def unpatchify(self, x, h, w):
        """Inverse of patchify: (b, h*w, p*p*c) tokens -> (b, c, h*p, w*p) image."""
        c = self.out_channels
        p = self.patch_size

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        """(b, c, h, w) image -> (b, tokens, c*p*p), padding h/w up to the
        patch size first. NOTE(review): the (H + 1) // patch_size arithmetic
        matches the padded size only for patch_size == 2 — confirm if reused."""
        B, C, H, W = x.size()
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        x = x.view(
            B,
            C,
            (H + 1) // self.patch_size,
            self.patch_size,
            (W + 1) // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

    def apply_pos_embeds(self, x, h, w):
        """Add the learned PE to token sequence x for an h x w (pixel) input:
        upsample the PE grid if needed, then center-crop to the token grid."""
        h = (h + 1) // self.patch_size
        w = (w + 1) // self.patch_size
        max_dim = max(h, w)

        cur_dim = self.h_max
        pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)

        if max_dim > cur_dim:
            # Input larger than the stored PE grid: bilinearly upsample it.
            pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
            cur_dim = max_dim

        from_h = (cur_dim - h) // 2
        from_w = (cur_dim - w) // 2
        pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
        return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

    def forward(self, x, timestep, context, transformer_options={}, **kwargs):
        """x: (b, c, h, w) latent; timestep: (b,); context: (b, T, cond_seq_dim)
        text tokens. Returns the denoised latent, cropped back to (h, w)."""
        patches_replace = transformer_options.get("patches_replace", {})
        # patchify x, add PE
        b, c, h, w = x.shape

        # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
        # print(pe_indexes, pe_indexes.shape)

        x = self.init_x_linear(self.patchify(x))  # B, T_x, D
        x = self.apply_pos_embeds(x, h, w)
        # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
        # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)

        # process conditions for MMDiT Blocks
        c_seq = context  # B, T_c, D_c
        t = timestep

        c = self.cond_seq_linear(c_seq)  # B, T_c, D
        # Prepend the 8 learned register tokens to the text stream.
        c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)

        global_cond = self.t_embedder(t, x.dtype)  # B, D

        blocks_replace = patches_replace.get("dit", {})
        if len(self.double_layers) > 0:
            for i, layer in enumerate(self.double_layers):
                if ("double_block", i) in blocks_replace:
                    # Patch hook: expose the block through the standard
                    # {"img", "txt", "vec"} dict interface.
                    def block_wrap(args):
                        out = {}
                        out["txt"], out["img"] = layer(args["txt"],
                                                       args["img"],
                                                       args["vec"])
                        return out
                    out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
                    c = out["txt"]
                    x = out["img"]
                else:
                    c, x = layer(c, x, global_cond, **kwargs)

        if len(self.single_layers) > 0:
            # Single-stream phase: run over the concatenated [text | image]
            # sequence, then keep only the image part.
            c_len = c.size(1)
            cx = torch.cat([c, x], dim=1)
            for i, layer in enumerate(self.single_layers):
                if ("single_block", i) in blocks_replace:
                    def block_wrap(args):
                        out = {}
                        out["img"] = layer(args["img"], args["vec"])
                        return out

                    out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
                    cx = out["img"]
                else:
                    cx = layer(cx, global_cond, **kwargs)

            x = cx[:, c_len:]

        # Final adaLN modulation + projection, then unpatchify and crop off padding.
        fshift, fscale = self.modF(global_cond).chunk(2, dim=1)

        x = modulate(x, fshift, fscale)
        x = self.final_linear(x)
        x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
        return x
ComfyUI/comfy/ldm/cascade/common.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from comfy.ldm.modules.attention import optimized_attention
22
+ import comfy.ops
23
+
24
class OptimizedAttention(nn.Module):
    """Multi-head attention with separate q/k/v projections.

    Delegates the attention computation itself to ComfyUI's
    ``optimized_attention`` backend selector.
    """
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead

        # One linear projection per attention input stream, all c -> c.
        self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, q, k, v):
        """Project q/k/v, attend, then apply the output projection."""
        projected_q = self.to_q(q)
        projected_k = self.to_k(k)
        projected_v = self.to_v(v)

        attended = optimized_attention(projected_q, projected_k, projected_v, self.heads)

        return self.out_proj(attended)
43
+
44
class Attention2D(nn.Module):
    """Attention over the spatial positions of a 2D feature map.

    Flattens (B, C, H, W) to (B, H*W, C) token form, attends against ``kv``
    (optionally prepending the flattened input for joint self-attention),
    and restores the original spatial shape.
    """
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)

    def forward(self, x, kv, self_attn=False):
        shape_in = x.shape
        tokens = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # BxCxHxW -> Bx(HW)xC
        if self_attn:
            keys_values = torch.cat([tokens, kv], dim=1)
        else:
            keys_values = kv
        tokens = self.attn(tokens, keys_values, keys_values)
        return tokens.permute(0, 2, 1).view(*shape_in)
59
+
60
+
61
def LayerNorm2d_op(operations):
    """Build a LayerNorm variant for channels-first 2D feature maps.

    Returns a subclass of ``operations.LayerNorm`` that normalizes the
    channel dimension of (B, C, H, W) tensors by permuting to channels-last,
    applying LayerNorm, and permuting back.
    """
    class LayerNorm2d(operations.LayerNorm):
        def forward(self, x):
            channels_last = x.permute(0, 2, 3, 1)
            normed = super().forward(channels_last)
            return normed.permute(0, 3, 1, 2)

    return LayerNorm2d
69
+
70
class GlobalResponseNorm(nn.Module):
    """Global Response Normalization for channels-last tensors.

    From https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    """
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))

    def forward(self, x):
        # L2 norm over the spatial dims, then normalize by the channel-mean norm.
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        response = spatial_norm / (spatial_norm.mean(dim=-1, keepdim=True) + 1e-6)
        gamma = comfy.ops.cast_to_input(self.gamma, x)
        beta = comfy.ops.cast_to_input(self.beta, x)
        return gamma * (x * response) + beta + x
81
+
82
+
83
class ResBlock(nn.Module):
    """Depthwise-conv + channelwise-MLP residual block (ConvNeXt style)."""
    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None):  # , num_heads=4, expansion=2):
        super().__init__()
        # Depthwise spatial mixing (groups == channels).
        self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        # Pointwise (channelwise) MLP applied in channels-last layout; an
        # optional skip tensor is concatenated on the channel axis first.
        self.channelwise = nn.Sequential(
            operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device),
        )

    def forward(self, x, x_skip=None):
        residual = x
        h = self.norm(self.depthwise(x))
        if x_skip is not None:
            h = torch.cat([h, x_skip], dim=1)
        h = self.channelwise(h.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return h + residual
104
+
105
+
106
class AttnBlock(nn.Module):
    """Cross-attention block: conditioning is mapped into kv space, then the
    normalized feature map attends to it (joint self/cross attention when
    ``self_attn=True``), added residually.
    """
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_cond, c, dtype=dtype, device=device),
        )

    def forward(self, x, kv):
        mapped_kv = self.kv_mapper(kv)
        return x + self.attention(self.norm(x), mapped_kv, self_attn=self.self_attn)
121
+
122
+
123
class FeedForwardBlock(nn.Module):
    """Pre-norm channelwise MLP with a residual connection."""
    def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.channelwise = nn.Sequential(
            operations.Linear(c, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device),
        )

    def forward(self, x):
        # Norm -> channels-last MLP -> channels-first, added residually.
        branch = self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x + branch
138
+
139
+
140
class TimestepBlock(nn.Module):
    """Per-channel affine modulation from timestep (and extra condition) embeddings.

    ``t`` packs the base timestep embedding plus one embedding per entry of
    ``conds`` along dim 1; each gets its own linear mapper, and the resulting
    scale/shift pairs are summed before modulating ``x``.
    """
    def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
        super().__init__()
        self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
        self.conds = conds
        for cname in conds:
            setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))

    def forward(self, x, t):
        chunks = t.chunk(len(self.conds) + 1, dim=1)
        scale, shift = self.mapper(chunks[0])[:, :, None, None].chunk(2, dim=1)
        for idx, cname in enumerate(self.conds):
            extra_scale, extra_shift = getattr(self, f"mapper_{cname}")(chunks[idx + 1])[:, :, None, None].chunk(2, dim=1)
            scale = scale + extra_scale
            shift = shift + extra_shift
        return x * (1 + scale) + shift
ComfyUI/comfy/ldm/cascade/controlnet.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torchvision
20
+ from torch import nn
21
+ from .common import LayerNorm2d_op
22
+
23
+
24
class CNetResBlock(nn.Module):
    """Pre-norm residual conv block used by the 'large' ControlNet bottleneck.

    Two (LayerNorm2d -> GELU -> 3x3 Conv) stages applied residually.
    """
    def __init__(self, c, dtype=None, device=None, operations=None):
        super().__init__()
        self.blocks = nn.Sequential(
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            # Forward dtype/device to the convs too: the original only passed
            # them to the norms, leaving the convs on the default dtype/device,
            # inconsistent with every other layer construction in these files.
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
        )

    def forward(self, x):
        """Return ``x + blocks(x)``."""
        return x + self.blocks(x)
38
+
39
+
40
class ControlNet(nn.Module):
    """Stable Cascade ControlNet: a feature backbone plus per-block projections.

    The control image is encoded by one of three backbones selected by
    ``bottleneck_mode`` ('effnet' default, 'simple', or 'large'), and the
    resulting feature map is projected once per entry of ``proj_blocks`` to
    ``c_proj`` channels for injection into the generator's blocks.
    """
    def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
        super().__init__()
        if bottleneck_mode is None:
            bottleneck_mode = 'effnet'
        # Indices of the generator blocks that receive a control projection.
        # NOTE(review): proj_blocks=None would crash in len() below — callers
        # presumably always pass a list; confirm against call sites.
        self.proj_blocks = proj_blocks
        if bottleneck_mode == 'effnet':
            embd_channels = 1280
            # Pretrained-architecture EfficientNetV2-S feature extractor, kept in eval mode.
            self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
            if c_in != 3:
                # Rebuild the stem conv for a non-RGB input channel count,
                # reusing the original stem weights where channels overlap.
                in_weights = self.backbone[0][0].weight.data
                self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
                if c_in > 3:
                    # nn.init.constant_(self.backbone[0][0].weight, 0)
                    self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
                else:
                    self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
        elif bottleneck_mode == 'simple':
            # Lightweight 2-conv bottleneck that keeps the input channel count.
            embd_channels = c_in
            self.backbone = nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
            )
        elif bottleneck_mode == 'large':
            # Deep 1x1-conv bottleneck with 8 residual blocks at 1024 channels.
            self.backbone = nn.Sequential(
                operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
                *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
                operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
            )
            embd_channels = 1280
        else:
            raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
        # One 2-conv projection head per target generator block.
        self.projections = nn.ModuleList()
        for _ in range(len(proj_blocks)):
            self.projections.append(nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
            ))
        # nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
        # Metadata consumed by ComfyUI's controlnet wrapper code.
        self.xl = False
        self.input_channels = c_in
        self.unshuffle_amount = 8

    def forward(self, x):
        """Encode ``x`` and return per-block projections, reversed, keyed "input".

        Slots not listed in ``proj_blocks`` are left as None.
        """
        x = self.backbone(x)
        proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
        for i, idx in enumerate(self.proj_blocks):
            proj_outputs[idx] = self.projections[i](x)
        return {"input": proj_outputs[::-1]}
ComfyUI/comfy/ldm/cascade/stage_a.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.autograd import Function
22
+ import comfy.ops
23
+
24
+ ops = comfy.ops.disable_weight_init
25
+
26
+
27
class vector_quantize(Function):
    """Straight-through nearest-neighbour vector quantization.

    ``forward`` returns the codebook rows nearest (squared L2) to each input
    row plus their indices; ``backward`` passes the output gradient straight
    through to the inputs and scatters it into the selected codebook rows.
    """
    @staticmethod
    def forward(ctx, x, codebook):
        with torch.no_grad():
            # Squared distance ||x - c||^2 = ||c||^2 + ||x||^2 - 2 x.c, with the
            # -2 x @ c^T term fused into the precomputed norms via addmm.
            codebook_sqr = torch.sum(codebook ** 2, dim=1)
            x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)

            dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
            _, indices = dist.min(dim=1)

            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)

            # Renamed from `nn`, which shadowed the torch.nn module import.
            nearest = torch.index_select(codebook, 0, indices)
            return nearest, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        grad_inputs, grad_codebook = None, None

        if ctx.needs_input_grad[0]:
            # Straight-through estimator: pass the gradient of the quantized
            # output to the inputs unchanged.
            grad_inputs = grad_output.clone()
        if ctx.needs_input_grad[1]:
            # Gradient wrt. the codebook: accumulate each output gradient into
            # the codebook row selected for it in forward.
            indices, codebook = ctx.saved_tensors

            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output)

        return (grad_inputs, grad_codebook)
+ return (grad_inputs, grad_codebook)
57
+
58
+
59
class VectorQuantize(nn.Module):
    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
        with the same size as the input, vq and commitment components for the loss as a tuple
        in the second output and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices
        """
        super(VectorQuantize, self).__init__()

        # k codebook entries of dimension embedding_size, uniformly initialized.
        self.codebook = nn.Embedding(k, embedding_size)
        self.codebook.weight.data.uniform_(-1./k, 1./k)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            # Running statistics for the EMA codebook update (VQ-VAE-2 style).
            self.register_buffer('ema_element_count', torch.ones(k))
            self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        # Additive smoothing of the per-entry counts so no entry's count can
        # collapse to zero (avoids division by ~0 in the EMA update).
        n = torch.sum(x)
        return ((x + epsilon) / (n + x.size(0) * epsilon) * n)

    def _updateEMA(self, z_e_x, indices):
        """EMA update of the codebook from this batch's assignments (training only)."""
        mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
        elem_count = mask.sum(dim=0)
        weight_sum = torch.mm(mask.t(), z_e_x)

        self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
        self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)

        # New codebook entries = EMA mean of the vectors assigned to each entry.
        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        """Look up codebook vectors for indices; move the embedding axis to `dim`."""
        q_idx = self.codebook(idx)
        if dim != -1:
            q_idx = q_idx.movedim(-1, dim)
        return q_idx

    def forward(self, x, get_losses=True, dim=-1):
        """Quantize x along `dim`; returns (quantized, (vq_loss, commit_loss), indices)."""
        if dim != -1:
            x = x.movedim(dim, -1)
        # Flatten all leading dims so each row is one vector to quantize.
        z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
        vq_loss, commit_loss = None, None
        if self.ema_loss and self.training:
            self._updateEMA(z_e_x.detach(), indices.detach())
        # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
        z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
        if get_losses:
            vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
            commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()

        z_q_x = z_q_x.view(x.shape)
        if dim != -1:
            z_q_x = z_q_x.movedim(-1, dim)
        return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
119
+
120
+
121
class ResBlock(nn.Module):
    """StageA residual block: a depthwise-conv branch and a channelwise-MLP
    branch, each gated by learned per-block modulation scalars (``gammas``).
    """
    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            ops.Conv2d(c, c, kernel_size=3, groups=c)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            ops.Linear(c, c_hidden),
            nn.GELU(),
            ops.Linear(c_hidden, c),
        )

        # Six modulation scalars: (scale, shift, gate) for each of the two branches.
        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights
        def _basic_init(module):
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        # Apply a LayerNorm over channels to a channels-first (B, C, H, W) tensor.
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        mods = self.gammas

        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        try:
            x = x + self.depthwise(x_temp) * mods[2]
        except (RuntimeError, NotImplementedError):
            # Narrowed from a bare `except:` which would also swallow
            # unrelated failures (e.g. KeyboardInterrupt/shape bugs).
            # Fallback: the padding op is not implemented for bf16, so run the
            # pad in float32 and the conv in the original dtype.
            x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
            x = x + self.depthwise[1](x_temp) * mods[2]

        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]

        return x
167
+
168
+
169
class StageA(nn.Module):
    """Stable Cascade Stage A: pixel-space VQ autoencoder.

    Encoder: PixelUnshuffle stem -> strided-conv/ResBlock pyramid -> 1x1 conv
    to ``c_latent`` channels with BatchNorm. Decoder mirrors it.
    """
    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
        super().__init__()
        self.c_latent = c_latent
        # Channel count per level, shallowest first (e.g. [192, 384] by default).
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)
        # (Removed a stray no-op expression statement `self.down_blocks[0]`.)

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            # Deepest level gets the full bottleneck stack; the rest one block each.
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                        padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        """Encode pixels to the latent.

        With ``quantize=True`` returns (quantized, latent, indices, loss);
        otherwise returns just the latent tensor.
        """
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
            return qe, x, indices, vq_loss + commit_loss * 0.25
        else:
            return x

    def decode(self, x):
        """Decode a latent back to pixel space."""
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        """Round-trip encode/decode; returns (reconstruction, vq_loss_or_None)."""
        if quantize:
            qe, _, _, vq_loss = self.encode(x, True)
        else:
            # Fix: encode() returns a single tensor when not quantizing; the
            # previous 4-way unpack raised a ValueError for quantize=False.
            qe, vq_loss = self.encode(x, False), None
        return self.decode(qe), vq_loss
232
+
233
+
234
class Discriminator(nn.Module):
    """Convolutional discriminator with spectral-normalized strided convs and
    an optional conditioning vector broadcast over the spatial map."""
    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        layers = [
            nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            # Channel count doubles each stage until it reaches c_hidden.
            ch_in = c_hidden // (2 ** max(d - i, 0))
            ch_out = c_hidden // (2 ** max(d - 1 - i, 0))
            layers += [
                nn.utils.spectral_norm(ops.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1)),
                nn.InstanceNorm2d(ch_out),
                nn.LeakyReLU(0.2),
            ]
        self.encoder = nn.Sequential(*layers)
        self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        feats = self.encoder(x)
        if cond is not None:
            # Broadcast the conditioning vector over the feature map and concat.
            cond_map = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, feats.size(-2), feats.size(-1))
            feats = torch.cat([feats, cond_map], dim=1)
        return self.logits(self.shuffle(feats))
ComfyUI/comfy/ldm/cascade/stage_b.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import math
20
+ import torch
21
+ from torch import nn
22
+ from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
23
+
24
class StageB(nn.Module):
    """Stable Cascade Stage B diffusion UNet.

    A multi-level encoder/decoder built from C (ResBlock), T (TimestepBlock),
    A (AttnBlock) and F (FeedForwardBlock) blocks per ``level_config``,
    conditioned on an effnet latent, optional pixels, CLIP embeddings and a
    timestep ratio ``r``.
    """
    def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
                 nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
                 block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
                 c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
                 t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r  # timestep embedding width
        self.t_conds = t_conds  # extra timestep-like conditions ('sca' = ?; see TimestepBlock)
        self.c_clip_seq = c_clip_seq  # CLIP tokens produced per input CLIP vector
        # Broadcast scalar dropout/self_attn settings to one entry per level.
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        # Maps the Stage C effnet latent into the first level's channel space.
        self.effnet_mapper = nn.Sequential(
            operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        # Same for the (low-res) pixel guidance input.
        self.pixels_mapper = nn.Sequential(
            operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        # Expands each CLIP vector into c_clip_seq conditioning tokens.
        self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        # Patchify the noisy latent into the first level's channel space.
        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            # Factory translating a level_config character into a block module.
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            # Strided conv downscaler between levels (identity at level 0).
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            # Optional 1x1 conv mappers applied between repeated passes of the level.
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks (mirror of the encoder, deepest level first)
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # The first block of each non-deepest level takes the skip connection.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
                                      self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT: project back to c_out channels and un-patchify.
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # self.apply(self._init_weights) # General init
        # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
        # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
        # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
        # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
        # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
        # nn.init.constant_(self.clf[1].weight, 0) # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
        #
        # def _init_weights(self, m):
        #     if isinstance(m, (nn.Conv2d, nn.Linear)):
        #         torch.nn.init.xavier_uniform_(m.weight)
        #         if m.bias is not None:
        #             nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of the timestep ratio ``r``; returns (B, c_r)."""
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip):
        """Expand CLIP vectors into c_clip_seq normalized conditioning tokens each."""
        if len(clip.shape) == 2:
            clip = clip.unsqueeze(1)
        clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip):
        """Run the encoder half; returns per-level features, deepest first.

        Blocks are dispatched by type (the hasattr checks also unwrap
        FSDP-wrapped modules) to route the right conditioning to each.
        """
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            # Run the whole level len(repmap)+1 times, with a 1x1 mapper between passes.
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip):
        """Run the decoder half, consuming encoder skips (deepest first)."""
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        # First block of each non-deepest level consumes the skip;
                        # resize if the upscaled map doesn't match the skip exactly.
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
        """Denoising forward pass; extra t_conds are read from kwargs (default 0)."""
        if pixels is None:
            # No pixel guidance provided: use a small zero image.
            pixels = x.new_zeros(x.size(0), 3, 8, 8)

        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip)

        # Model Blocks
        x = self.embedding(x)
        # Add effnet-latent and pixel guidance, both resampled to the feature map size.
        x = x + self.effnet_mapper(
            nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
        x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
                                          align_corners=True)
        level_outputs = self._down_encode(x, r_embed, clip)
        x = self._up_decode(level_outputs, r_embed, clip)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """EMA-update this model's parameters and buffers toward ``src_model``."""
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
ComfyUI/comfy/ldm/cascade/stage_c.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ from torch import nn
21
+ import math
22
+ from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
23
+ # from .controlnet import ControlNetDeliverer
24
+
25
class UpDownBlock2d(nn.Module):
    """A 1x1 channel-mapping convolution paired with optional 2x bilinear resampling.

    ``mode='up'`` resamples first (scale factor 2) and then maps channels;
    ``mode='down'`` maps channels first and then resamples (scale factor 0.5).
    With ``enabled=False`` the resampling step is an identity, leaving only the
    channel mapping.
    """

    def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert mode in ['up', 'down']
        if enabled:
            scale = 2 if mode == 'up' else 0.5
            resample = nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True)
        else:
            resample = nn.Identity()
        mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
        ordered = [resample, mapping] if mode == 'up' else [mapping, resample]
        self.blocks = nn.ModuleList(ordered)

    def forward(self, x):
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
38
+
39
+
40
class StageC(nn.Module):
    """Stable Cascade Stage C denoiser.

    A two-level U-Net assembled from 'C' (ResBlock), 'T' (TimestepBlock),
    'A' (AttnBlock) and 'F' (FeedForwardBlock) blocks per ``level_config``,
    conditioned on a sinusoidal timestep-ratio embedding (plus extra
    conditions named in ``t_conds``) and on CLIP text / pooled-text / image
    embeddings. Optional ControlNet features are added before each ResBlock
    during both encode and decode.
    """

    def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
                 blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
                 c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
                 dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
                 dtype=None, device=None, operations=None):
        # NOTE(review): stable_cascade_stage is accepted but never read here.
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        # Broadcast scalar dropout / self_attn settings to one entry per level.
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING: project the three CLIP embeddings into the shared
        # c_cond space; pooled-text and image embeddings are expanded to
        # c_clip_seq tokens each in gen_c_embeddings.
        self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
        self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        # Patchify the latent input and project it to the first hidden width.
        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            # Factory mapping a level_config character to its block class.
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # BLOCKS
        # -- down blocks: each level gets a downscaler (identity at level 0),
        # blocks[0][i] repetitions of the level_config pattern, and optional
        # 1x1 "repeat mapper" convs applied between whole-level repetitions.
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks: mirror of the down path, iterated from the deepest
        # level outward; the first ResBlock of each non-deepest level takes a
        # skip connection (c_skip) from the encoder.
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # Only the very first block of a non-final level receives a skip.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
                                      self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT: norm, project to c_out * patch_size^2 channels, un-patchify.
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # self.apply(self._init_weights)  # General init
        # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02)  # conditionings
        # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)  # conditionings
        # nn.init.normal_(self.clip_img_mapper.weight, std=0.02)  # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        # nn.init.constant_(self.clf[1].weight, 0)  # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
        #
        # def _init_weights(self, m):
        #     if isinstance(m, (nn.Conv2d, nn.Linear)):
        #         torch.nn.init.xavier_uniform_(m.weight)
        #         if m.bias is not None:
        #             nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of the timestep ratio ``r`` into c_r dims.

        ``r`` is scaled by max_positions before the standard sin/cos embedding;
        an odd c_r is zero-padded by one column.
        """
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
        """Fuse text, pooled-text and image CLIP embeddings into one sequence.

        2-D pooled/image inputs are promoted to a singleton token dimension;
        pooled-text and image embeddings are each expanded to c_clip_seq
        tokens per input token, then all three are concatenated and normalized.
        """
        clip_txt = self.clip_txt_mapper(clip_txt)
        if len(clip_txt_pooled.shape) == 2:
            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
        if len(clip_img.shape) == 2:
            clip_img = clip_img.unsqueeze(1)
        clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
        clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
        clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip, cnet=None):
        """Encoder pass; returns per-level features, deepest level first.

        cnet, when given, is a list consumed via pop() — one optional
        ControlNet feature per ResBlock, added (resized) before the block.
        The isinstance checks also accept FSDP-wrapped modules.
        """
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        if cnet is not None:
                            next_cnet = cnet.pop()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True).to(x.dtype)
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
        """Decoder pass over ``level_outputs`` (deepest level first).

        The first ResBlock of each shallower level consumes a skip from
        level_outputs; x is resized to the skip's spatial size if they differ.
        cnet features (popped per ResBlock) are added the same way as in
        _down_encode.
        """
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        if cnet is not None:
                            next_cnet = cnet.pop()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True).to(x.dtype)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
        """Denoise latent ``x`` at timestep ratio ``r`` with CLIP conditioning.

        control, when given, is a dict whose "input" entry holds the
        ControlNet feature list consumed by the encoder/decoder.
        """
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)

        if control is not None:
            cnet = control.get("input")
        else:
            cnet = None

        # Model Blocks
        x = self.embedding(x)
        level_outputs = self._down_encode(x, r_embed, clip, cnet)
        x = self._up_decode(level_outputs, r_embed, clip, cnet)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """EMA-blend parameters and buffers toward ``src_model`` in place."""
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
ComfyUI/comfy/ldm/cascade/stage_c_coder.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+ import torch
19
+ import torchvision
20
+ from torch import nn
21
+
22
+ import comfy.ops
23
+
24
+ ops = comfy.ops.disable_weight_init
25
+
26
+ # EfficientNet
27
# EfficientNet
class EfficientNetEncoder(nn.Module):
    """Encode RGB images to Stage C latents using an EfficientNetV2-S backbone.

    Inputs are expected in [-1, 1]; they are rescaled to [0, 1] and
    ImageNet-normalized before the backbone, then projected to ``c_latent``
    channels and batch-normalized (affine=False) to zero mean / unit std.
    """

    def __init__(self, c_latent=16):
        super().__init__()
        self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
        self.mapper = nn.Sequential(
            ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent, affine=False),  # then normalize them to have mean 0 and std 1
        )
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))

    def forward(self, x):
        # [-1, 1] -> [0, 1], then ImageNet normalization broadcast over (C, 1, 1).
        x = x * 0.5 + 0.5
        mean = self.mean.view([3, 1, 1]).to(device=x.device, dtype=x.dtype)
        std = self.std.view([3, 1, 1]).to(device=x.device, dtype=x.dtype)
        x = (x - mean) / std
        return self.mapper(self.backbone(x))
43
+
44
+
45
+ # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
46
+ class Previewer(nn.Module):
47
+ def __init__(self, c_in=16, c_hidden=512, c_out=3):
48
+ super().__init__()
49
+ self.blocks = nn.Sequential(
50
+ ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
51
+ nn.GELU(),
52
+ nn.BatchNorm2d(c_hidden),
53
+
54
+ ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
55
+ nn.GELU(),
56
+ nn.BatchNorm2d(c_hidden),
57
+
58
+ ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
59
+ nn.GELU(),
60
+ nn.BatchNorm2d(c_hidden // 2),
61
+
62
+ ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
63
+ nn.GELU(),
64
+ nn.BatchNorm2d(c_hidden // 2),
65
+
66
+ ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
67
+ nn.GELU(),
68
+ nn.BatchNorm2d(c_hidden // 4),
69
+
70
+ ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
71
+ nn.GELU(),
72
+ nn.BatchNorm2d(c_hidden // 4),
73
+
74
+ ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
75
+ nn.GELU(),
76
+ nn.BatchNorm2d(c_hidden // 4),
77
+
78
+ ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
79
+ nn.GELU(),
80
+ nn.BatchNorm2d(c_hidden // 4),
81
+
82
+ ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
83
+ )
84
+
85
+ def forward(self, x):
86
+ return (self.blocks(x) - 0.5) * 2.0
87
+
88
class StageC_coder(nn.Module):
    """Bundle of the Stage C latent encoder and preview decoder.

    ``encode`` maps RGB images to latents via EfficientNetEncoder;
    ``decode`` maps latents to preview RGB via Previewer.
    """

    def __init__(self):
        super().__init__()
        self.previewer = Previewer()
        self.encoder = EfficientNetEncoder()

    def encode(self, x):
        """Encode an RGB image tensor into Stage C latents."""
        return self.encoder(x)

    def decode(self, x):
        """Decode Stage C latents into a preview RGB image."""
        return self.previewer(x)
ComfyUI/comfy/ldm/chroma/layers.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor, nn
3
+
4
+ from comfy.ldm.flux.math import attention
5
+ from comfy.ldm.flux.layers import (
6
+ MLPEmbedder,
7
+ RMSNorm,
8
+ QKNorm,
9
+ SelfAttention,
10
+ ModulationOut,
11
+ )
12
+
13
+
14
+
15
+ class ChromaModulationOut(ModulationOut):
16
+ @classmethod
17
+ def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
18
+ return cls(
19
+ shift=tensor[:, offset : offset + 1, :],
20
+ scale=tensor[:, offset + 1 : offset + 2, :],
21
+ gate=tensor[:, offset + 2 : offset + 3, :],
22
+ )
23
+
24
+
25
+
26
+
27
+ class Approximator(nn.Module):
28
+ def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None):
29
+ super().__init__()
30
+ self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
31
+ self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
32
+ self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
33
+ self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
34
+
35
+ @property
36
+ def device(self):
37
+ # Get the device of the module (assumes all parameters are on the same device)
38
+ return next(self.parameters()).device
39
+
40
+ def forward(self, x: Tensor) -> Tensor:
41
+ x = self.in_proj(x)
42
+
43
+ for layer, norms in zip(self.layers, self.norms):
44
+ x = x + layer(norms(x))
45
+
46
+ x = self.out_proj(x)
47
+
48
+ return x
49
+
50
+
51
+ class DoubleStreamBlock(nn.Module):
52
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
53
+ super().__init__()
54
+
55
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
56
+ self.num_heads = num_heads
57
+ self.hidden_size = hidden_size
58
+ self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
59
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
60
+
61
+ self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
62
+ self.img_mlp = nn.Sequential(
63
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
64
+ nn.GELU(approximate="tanh"),
65
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
66
+ )
67
+
68
+ self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
69
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
70
+
71
+ self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
72
+ self.txt_mlp = nn.Sequential(
73
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
74
+ nn.GELU(approximate="tanh"),
75
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
76
+ )
77
+ self.flipped_img_txt = flipped_img_txt
78
+
79
+ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
80
+ (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
81
+
82
+ # prepare image for attention
83
+ img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
84
+ img_qkv = self.img_attn.qkv(img_modulated)
85
+ img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
86
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
87
+
88
+ # prepare txt for attention
89
+ txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
90
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
91
+ txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
92
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
93
+
94
+ # run actual attention
95
+ attn = attention(torch.cat((txt_q, img_q), dim=2),
96
+ torch.cat((txt_k, img_k), dim=2),
97
+ torch.cat((txt_v, img_v), dim=2),
98
+ pe=pe, mask=attn_mask)
99
+
100
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
101
+
102
+ # calculate the img bloks
103
+ img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
104
+ img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
105
+
106
+ # calculate the txt bloks
107
+ txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
108
+ txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
109
+
110
+ if txt.dtype == torch.float16:
111
+ txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
112
+
113
+ return img, txt
114
+
115
+
116
+ class SingleStreamBlock(nn.Module):
117
+ """
118
+ A DiT block with parallel linear layers as described in
119
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ hidden_size: int,
125
+ num_heads: int,
126
+ mlp_ratio: float = 4.0,
127
+ qk_scale: float = None,
128
+ dtype=None,
129
+ device=None,
130
+ operations=None
131
+ ):
132
+ super().__init__()
133
+ self.hidden_dim = hidden_size
134
+ self.num_heads = num_heads
135
+ head_dim = hidden_size // num_heads
136
+ self.scale = qk_scale or head_dim**-0.5
137
+
138
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
139
+ # qkv and mlp_in
140
+ self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
141
+ # proj and mlp_out
142
+ self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
143
+
144
+ self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
145
+
146
+ self.hidden_size = hidden_size
147
+ self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
148
+
149
+ self.mlp_act = nn.GELU(approximate="tanh")
150
+
151
+ def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
152
+ mod = vec
153
+ x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
154
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
155
+
156
+ q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
157
+ q, k = self.norm(q, k, v)
158
+
159
+ # compute attention
160
+ attn = attention(q, k, v, pe=pe, mask=attn_mask)
161
+ # compute activation in mlp stream, cat again and run second linear layer
162
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
163
+ x.addcmul_(mod.gate, output)
164
+ if x.dtype == torch.float16:
165
+ x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
166
+ return x
167
+
168
+
169
+ class LastLayer(nn.Module):
170
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
171
+ super().__init__()
172
+ self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
173
+ self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
174
+
175
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
176
+ shift, scale = vec
177
+ shift = shift.squeeze(1)
178
+ scale = scale.squeeze(1)
179
+ x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
180
+ x = self.linear(x)
181
+ return x
ComfyUI/comfy/ldm/chroma/model.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Original code can be found on: https://github.com/black-forest-labs/flux
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+ from einops import rearrange, repeat
8
+ import comfy.ldm.common_dit
9
+
10
+ from comfy.ldm.flux.layers import (
11
+ EmbedND,
12
+ timestep_embedding,
13
+ )
14
+
15
+ from .layers import (
16
+ DoubleStreamBlock,
17
+ LastLayer,
18
+ SingleStreamBlock,
19
+ Approximator,
20
+ ChromaModulationOut,
21
+ )
22
+
23
+
24
@dataclass
class ChromaParams:
    """Hyper-parameters for the Chroma transformer (consumed by Chroma.__init__)."""
    in_channels: int          # channels of the patchified latent input
    out_channels: int         # channels of the predicted output
    context_in_dim: int       # dimension of the incoming text-context embedding
    hidden_size: int          # transformer width; must be divisible by num_heads
    mlp_ratio: float          # MLP hidden dim = hidden_size * mlp_ratio
    num_heads: int            # number of attention heads
    depth: int                # number of double-stream blocks
    depth_single_blocks: int  # number of single-stream blocks
    axes_dim: list            # rotary-embedding dims per axis; must sum to hidden_size // num_heads
    theta: int                # rotary-embedding base frequency
    patch_size: int           # spatial patch size for (un)patchifying
    qkv_bias: bool            # whether QKV projections use a bias term
    in_dim: int               # Approximator (distilled guidance) input dim
    out_dim: int              # Approximator output dim
    hidden_dim: int           # Approximator hidden dim
    n_layers: int             # Approximator depth
43
+
44
+
45
+
46
class Chroma(nn.Module):
    """
    Transformer model for flow matching on sequences.

    A Flux-derived DiT: patchified image tokens and text tokens pass through
    double-stream blocks (separate img/txt streams with joint attention) and
    then single-stream blocks over the concatenated sequence. Per-block
    modulation vectors are not learned per block; instead they are produced
    by a distilled-guidance Approximator from timestep/guidance embeddings.
    """

    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        params = ChromaParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.in_dim = params.in_dim
        self.out_dim = params.out_dim
        self.hidden_dim = params.hidden_dim
        self.n_layers = params.n_layers
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
        # set as nn identity for now, will overwrite it later.
        self.distilled_guidance_layer = Approximator(
                    in_dim=self.in_dim,
                    hidden_dim=self.hidden_dim,
                    out_dim=self.out_dim,
                    n_layers=self.n_layers,
                    dtype=dtype, device=device, operations=operations
                )


        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)

        # Block indices to bypass (can be set externally, e.g. for pruning).
        self.skip_mmdit = []
        self.skip_dit = []
        self.lite = False

    def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
        """Slice the per-block modulation vectors out of ``tensor``.

        The modulations tensor has the following layout along dim 1:
          single      : num_single_blocks * 3 elements
          double_img  : num_double_blocks * 6 elements
          double_txt  : num_double_blocks * 6 elements
          final       : 2 elements
        Returns a ChromaModulationOut (single), a pair of them (double_*),
        or a (shift, scale) slice pair (final).
        """
        if block_type == "final":
            return (tensor[:, -2:-1, :], tensor[:, -1:, :])
        single_block_count = self.params.depth_single_blocks
        double_block_count = self.params.depth
        offset = 3 * idx
        if block_type == "single":
            return ChromaModulationOut.from_offset(tensor, offset)
        # Double block modulations are 6 elements so we double 3 * idx.
        offset *= 2
        if block_type in {"double_img", "double_txt"}:
            # Advance past the single block modulations.
            offset += 3 * single_block_count
            if block_type == "double_txt":
                # Advance past the double block img modulations.
                offset += 6 * double_block_count
            return (
                ChromaModulationOut.from_offset(tensor, offset),
                ChromaModulationOut.from_offset(tensor, offset + 3),
            )
        raise ValueError("Bad block_type")


    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        guidance: Tensor = None,
        control = None,
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:
        """Core transformer pass over already-patchified sequences.

        img/txt: (B, L, C) token sequences; *_ids: positional id triples fed
        to the rotary embedder; control: optional ControlNet dict with
        "input"/"output" residual lists; transformer_options may carry
        per-block replacement callables under patches_replace["dit"].
        """
        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)

        # distilled vector guidance: build one input row per modulation slot
        # (344 = 3*single + 12*double + 2 final for the default config; see
        # get_modulations layout).
        mod_index_length = 344
        distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
        # guidance = guidance *
        distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)

        # get all modulation index
        modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
        # we need to broadcast the modulation index here so each batch has all of the index
        modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
        # and we need to broadcast timestep and guidance along too
        timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
        # then and only then we could concatenate it together
        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)

        mod_vectors = self.distilled_guidance_layer(input_vec)

        txt = self.txt_in(txt)

        # Rotary positional embeddings over the joint txt+img id sequence.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if i not in self.skip_mmdit:
                double_mod = (
                    self.get_modulations(mod_vectors, "double_img", idx=i),
                    self.get_modulations(mod_vectors, "double_txt", idx=i),
                )
                if ("double_block", i) in blocks_replace:
                    # Patched path: hand the original block to the replacement
                    # callable so it can wrap or substitute it.
                    def block_wrap(args):
                        out = {}
                        out["img"], out["txt"] = block(img=args["img"],
                                                       txt=args["txt"],
                                                       vec=args["vec"],
                                                       pe=args["pe"],
                                                       attn_mask=args.get("attn_mask"))
                        return out

                    out = blocks_replace[("double_block", i)]({"img": img,
                                                               "txt": txt,
                                                               "vec": double_mod,
                                                               "pe": pe,
                                                               "attn_mask": attn_mask},
                                                              {"original_block": block_wrap})
                    txt = out["txt"]
                    img = out["img"]
                else:
                    img, txt = block(img=img,
                                     txt=txt,
                                     vec=double_mod,
                                     pe=pe,
                                     attn_mask=attn_mask)

                if control is not None: # Controlnet
                    control_i = control.get("input")
                    if i < len(control_i):
                        add = control_i[i]
                        if add is not None:
                            img += add

        # Single-stream phase operates on the concatenated txt+img sequence.
        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            if i not in self.skip_dit:
                single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                if ("single_block", i) in blocks_replace:
                    def block_wrap(args):
                        out = {}
                        out["img"] = block(args["img"],
                                           vec=args["vec"],
                                           pe=args["pe"],
                                           attn_mask=args.get("attn_mask"))
                        return out

                    out = blocks_replace[("single_block", i)]({"img": img,
                                                               "vec": single_mod,
                                                               "pe": pe,
                                                               "attn_mask": attn_mask},
                                                              {"original_block": block_wrap})
                    img = out["img"]
                else:
                    img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)

                if control is not None: # Controlnet
                    control_o = control.get("output")
                    if i < len(control_o):
                        add = control_o[i]
                        if add is not None:
                            # Only the image part of the joint sequence gets the residual.
                            img[:, txt.shape[1] :, ...] += add

        # Drop the text tokens and project to output channels.
        img = img[:, txt.shape[1] :, ...]
        final_mod = self.get_modulations(mod_vectors, "final")
        img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
        return img

    def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
        """Patchify latent ``x``, run the transformer, and un-patchify.

        x: (B, C, H, W) latent; context: (B, L, context_in_dim) text tokens.
        Output is cropped back to the original H x W after padding to a
        multiple of patch_size.
        """
        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)

        # Rounded-up token grid size after padding.
        h_len = ((h + (self.patch_size // 2)) // self.patch_size)
        w_len = ((w + (self.patch_size // 2)) // self.patch_size)
        # Axis 0 of the id triple is unused (zeros); axes 1/2 carry row/col indices.
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
ComfyUI/comfy/ldm/cosmos/blocks.py ADDED
@@ -0,0 +1,797 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Optional
18
+ import logging
19
+
20
+ import numpy as np
21
+ import torch
22
+ from einops import rearrange, repeat
23
+ from einops.layers.torch import Rearrange
24
+ from torch import nn
25
+
26
+ from comfy.ldm.modules.attention import optimized_attention
27
+
28
+
29
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
    """Build a normalization layer from a one-letter code.

    Args:
        name: "I" for identity (no normalization) or "R" for RMSNorm.
        channels: feature dimension the norm operates over (ignored for "I").
        weight_args: extra keyword args (e.g. dtype/device) forwarded to the layer.
        operations: module namespace providing RMSNorm (required for "R").

    Raises:
        ValueError: if `name` is not a recognized code.
    """
    if name == "I":
        return nn.Identity()
    if name == "R":
        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
    raise ValueError(f"Normalization {name} not found")
36
+
37
+
38
class BaseAttentionOp(nn.Module):
    """Marker base class for pluggable attention operations."""

    def __init__(self):
        super().__init__()
41
+
42
+
43
class Attention(nn.Module):
    """
    Generalized attention impl.

    Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
    If `context_dim` is None, self-attention is assumed.

    Parameters:
        query_dim (int): Dimension of each query vector.
        context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
        heads (int, optional): Number of attention heads. Defaults to 8.
        dim_head (int, optional): Dimension of each head. Defaults to 64.
        dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
        attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
        qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
        out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
        qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
            Defaults to "SSI".
        qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
            Defaults to 'per_head'. Only support 'per_head'.

    Examples:
        >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
        >>> query = torch.randn(10, 128) # Batch size of 10
        >>> context = torch.randn(10, 256) # Batch size of 10
        >>> output = attn(query, context) # Perform the attention operation

    Note:
        https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """

    def __init__(
        self,
        query_dim: int,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        attn_op: Optional[BaseAttentionOp] = None,
        qkv_bias: bool = False,
        out_bias: bool = False,
        qkv_norm: str = "SSI",
        qkv_norm_mode: str = "per_head",
        backend: str = "transformer_engine",
        qkv_format: str = "bshd",
        weight_args={},
        operations=None,
    ) -> None:
        super().__init__()

        self.is_selfattn = context_dim is None  # self attention

        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head
        self.qkv_norm_mode = qkv_norm_mode
        self.qkv_format = qkv_format

        # Per-head normalization: the norm layers below act on `dim_head`-sized
        # slices, not on the full `inner_dim` projection.
        if self.qkv_norm_mode == "per_head":
            norm_dim = dim_head
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        # NOTE(review): `backend` is stored but not referenced elsewhere in this
        # class; `attn_op` is accepted but unused here — both appear to be kept
        # for config compatibility. Attention is computed via optimized_attention.
        self.backend = backend

        # Each projection is an nn.Sequential of [Linear, per-head norm]; the two
        # stages are invoked separately in cal_qkv (see the note there), but the
        # Sequential container preserves checkpoint key names.
        self.to_q = nn.Sequential(
            operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
        )
        self.to_k = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
        )
        self.to_v = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
        )

        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
            nn.Dropout(dropout),
        )

    def cal_qkv(
        self, x, context=None, mask=None, rope_emb=None, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project inputs to per-head q/k/v, normalize per head, and (for
        self-attention only) apply rotary position embeddings.

        Input `x` (and `context`) are laid out sequence-first: the rearrange
        below maps "s b (n c)" to "b n s c" — assumes (seq, batch, dim) inputs.
        Returns q, k, v each shaped (B, heads, S, dim_head).
        """
        del kwargs


        """
        self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
        Before 07/24/2024, these modules normalize across all heads.
        After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
        we support to normalize per head.
        To keep the checkpoint copatibility with the previous code,
        we keep the nn.Sequential but call the projection and the normalization layers separately.
        We use a flag `self.qkv_norm_mode` to control the normalization behavior.
        The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
        """
        if self.qkv_norm_mode == "per_head":
            # Stage 1: linear projections only (index [0] of each Sequential).
            q = self.to_q[0](x)
            context = x if context is None else context
            k = self.to_k[0](context)
            v = self.to_v[0](context)
            # Split heads so the norms below act per head: (s, b, n*c) -> (b, n, s, c).
            q, k, v = map(
                lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
                (q, k, v),
            )
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        # Stage 2: per-head normalization (index [1] of each Sequential).
        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            # apply_rotary_pos_emb inlined: split the head dim into (real, imag)
            # halves, rotate by the complex-valued rope_emb, and restore layout.
            q_shape = q.shape
            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)

            # apply_rotary_pos_emb inlined
            k_shape = k.shape
            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
        return q, k, v

    def forward(
        self,
        x,
        context=None,
        mask=None,
        rope_emb=None,
        **kwargs,
    ):
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
        """
        q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
        del q, k, v
        # Merge heads back and return to the sequence-first layout expected by to_out.
        out = rearrange(out, " b n s c -> s b (n c)")
        return self.to_out(out)
191
+
192
+
193
class FeedForward(nn.Module):
    """
    Transformer feed-forward network with an optional multiplicative gate.

    Parameters:
        d_model (int): Dimensionality of input features.
        d_ff (int): Dimensionality of the hidden layer.
        dropout (float, optional): Dropout rate; must be 0.0 at call time (the
            forward pass asserts that dropout is skipped). Defaults to 0.1.
        activation (callable, optional): Activation applied after the first
            linear layer. Defaults to nn.ReLU().
        is_gated (bool, optional): If True, multiplies the activated hidden
            state by a separate linear gate of the input. Defaults to False.
        bias (bool, optional): Whether the two main linear layers use a bias.
            Defaults to False.

    Example:
        >>> ff = FeedForward(d_model=512, d_ff=2048)
        >>> y = ff(torch.randn(64, 10, 512))
        >>> y.shape
        torch.Size([64, 10, 512])
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation=nn.ReLU(),
        is_gated: bool = False,
        bias: bool = False,
        weight_args={},
        operations=None,
    ) -> None:
        super().__init__()

        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.is_gated = is_gated

        self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
        self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)
        if is_gated:
            # Gate projection never uses a bias, matching the reference design.
            self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)

    def forward(self, x: torch.Tensor):
        assert self.dropout.p == 0.0, "we skip dropout"
        hidden = self.activation(self.layer1(x))
        if self.is_gated:
            hidden = hidden * self.linear_gate(x)
        return self.layer2(hidden)
244
+
245
+
246
class GPT2FeedForward(FeedForward):
    """GPT-2 style FFN: Linear -> GELU -> Linear, no gating, dropout skipped."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
        super().__init__(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=nn.GELU(),
            is_gated=False,
            bias=bias,
            weight_args=weight_args,
            operations=operations,
        )

    def forward(self, x: torch.Tensor):
        """Apply the two-layer GELU MLP; dropout must be disabled."""
        assert self.dropout.p == 0.0, "we skip dropout"
        return self.layer2(self.activation(self.layer1(x)))
267
+
268
+
269
def modulate(x, shift, scale):
    """AdaLN-style affine modulation: x * (1 + scale) + shift.

    `shift` and `scale` are (B, D); a length-1 axis is inserted so they
    broadcast over the token dimension of `x` (B, T, D).
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
271
+
272
+
273
class Timesteps(nn.Module):
    """Sinusoidal timestep embedding.

    Maps a (B,) tensor of timesteps to (B, num_channels) features laid out as
    [cos | sin] halves over log-spaced frequencies (base 10000).
    """

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        half = self.num_channels // 2
        # Log-spaced frequencies: exp(-log(10000) * i / half) for i in [0, half).
        scales = -math.log(10000) * torch.arange(half, dtype=torch.float32, device=timesteps.device)
        freqs = torch.exp(scales / half)

        angles = timesteps[:, None].float() * freqs[None, :]
        return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
291
+
292
+
293
class TimestepEmbedding(nn.Module):
    """Two-layer MLP over timestep features, with optional AdaLN-LoRA output.

    Without AdaLN-LoRA: returns (mlp(sample), None).
    With AdaLN-LoRA: the MLP produces a 3*out_features modulation tensor and
    the raw `sample` is passed through as the embedding instead, so the
    return is (sample, modulation).
    """

    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
        super().__init__()
        logging.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.use_adaln_lora = use_adaln_lora
        self.activation = nn.SiLU()
        # Bias only in the non-LoRA configuration (checkpoint compatibility).
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
        if use_adaln_lora:
            self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args)
        else:
            self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        hidden = self.linear_2(self.activation(self.linear_1(sample)))
        if self.use_adaln_lora:
            # MLP output becomes the AdaLN-LoRA modulation; embedding is the input itself.
            return sample, hidden
        return hidden, None
320
+
321
+
322
class FourierFeatures(nn.Module):
    """
    Random Fourier feature layer: [B] -> [B, num_channels].

    Frequencies and phases are sampled once at construction and stored as
    persistent buffers, so the mapping is fixed (per checkpoint) thereafter.

    Parameters:
        num_channels (int): Number of Fourier features to generate.
        bandwidth (float, optional): Scaling factor for the sampled
            frequencies. Defaults to 1.
        normalize (bool, optional): If True, scale outputs by sqrt(2) to
            normalize their variance. Defaults to False.

    Example:
        >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
        >>> layer(torch.randn(10)).shape
        torch.Size([10, 256])
    """

    def __init__(self, num_channels, bandwidth=1, normalize=False):
        super().__init__()
        self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
        self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
        self.gain = np.sqrt(2) if normalize else 1

    def forward(self, x, gain: float = 1.0):
        """Return cos(x ⊗ freqs + phases) * gain, computed in float32.

        Args:
            x (torch.Tensor): 1-D input tensor of shape (B,).
            gain (float, optional): Extra multiplicative gain. Defaults to 1.

        Returns:
            torch.Tensor: Features of shape (B, num_channels) in x's dtype.
        """
        orig_dtype = x.dtype
        angles = torch.outer(x.to(torch.float32), self.freqs.to(torch.float32)) + self.phases.to(torch.float32)
        return (torch.cos(angles) * (self.gain * gain)).to(orig_dtype)
363
+
364
+
365
class PatchEmbed(nn.Module):
    """
    Embeds a video tensor into a grid of patch embeddings.

    The input is split into non-overlapping patches of size
    (temporal_patch_size x spatial_patch_size x spatial_patch_size); each patch
    is flattened and linearly projected to `out_channels`. Works for images as
    well when T equals the temporal patch size.

    Parameters:
        spatial_patch_size (int): Size of each spatial patch.
        temporal_patch_size (int): Size of each temporal patch.
        in_channels (int): Number of input channels. Default: 3.
        out_channels (int): Embedding dimension per patch. Default: 768.
        bias (bool): If True, the projection layer has a learnable bias. Default: True.
    """

    def __init__(
        self,
        spatial_patch_size,
        temporal_patch_size,
        in_channels=3,
        out_channels=768,
        bias=True,
        weight_args={},
        operations=None,
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        patch_dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
        # Keep the Rearrange+Linear Sequential so checkpoint keys (proj.1.*) match.
        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            operations.Linear(patch_dim, out_channels, bias=bias, **weight_args),
        )
        self.out = nn.Identity()

    def forward(self, x):
        """
        Embed patches.

        Parameters:
            x (torch.Tensor): Input of shape (B, C, T, H, W); H and W must be
                divisible by the spatial patch size and T by the temporal one.

        Returns:
            torch.Tensor: Patch embeddings of shape (B, t, h, w, out_channels).
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        assert T % self.temporal_patch_size == 0
        return self.out(self.proj(x))
428
+
429
+
430
class FinalLayer(nn.Module):
    """
    The final layer of video DiT.

    Applies AdaLN modulation (shift/scale derived from the conditioning
    embedding, optionally corrected by an AdaLN-LoRA branch) to the normalized
    hidden states, then projects each token to patch-pixel space.
    """

    def __init__(
        self,
        hidden_size,
        spatial_patch_size,
        temporal_patch_size,
        out_channels,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args={},
        operations=None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_adaln_lora = use_adaln_lora
        # Two modulation chunks: shift and scale (no gate in the final layer).
        self.n_adaln_chunks = 2

        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
        )
        if use_adaln_lora:
            # Low-rank bottleneck variant of the modulation MLP.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
            )

    def forward(
        self,
        x_BT_HW_D,
        emb_B_D,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x_BT_HW_D: hidden states with batch and time flattened together
                on dim 0, i.e. shape (B*T, HW, D).
            emb_B_D: per-sample conditioning embedding, shape (B, D).
            adaln_lora_B_3D: required when use_adaln_lora; its first
                2*hidden_size columns are added to the modulation.

        Returns:
            Tensor of shape (B*T, HW, patch_pixels) after modulation + projection.
        """
        modulation = self.adaLN_modulation(emb_B_D)
        if self.use_adaln_lora:
            assert adaln_lora_B_3D is not None
            modulation = modulation + adaln_lora_B_3D[:, : 2 * self.hidden_size]
        shift_B_D, scale_B_D = modulation.chunk(2, dim=1)

        batch = emb_B_D.shape[0]
        frames = x_BT_HW_D.shape[0] // batch
        # Expand per-sample modulation to per-frame rows, batch-major: b -> (b t).
        shift_BT_D = shift_B_D.repeat_interleave(frames, dim=0)
        scale_BT_D = scale_B_D.repeat_interleave(frames, dim=0)

        # Inlined `modulate`: x * (1 + scale) + shift, broadcast over the HW axis.
        normed = self.norm_final(x_BT_HW_D)
        x_BT_HW_D = normed * (1 + scale_BT_D.unsqueeze(1)) + shift_BT_D.unsqueeze(1)

        return self.linear(x_BT_HW_D)
486
+
487
+
488
class VideoAttn(nn.Module):
    """
    Self- or cross-attention over flattened video tokens.

    The (T, H, W) spatio-temporal grid is flattened into one token axis,
    attention is applied via the shared `Attention` module, and the grid
    shape is restored afterwards.

    Parameters:
        x_dim (int): Dimension of input feature vectors.
        context_dim (Optional[int]): Dimension of context features for
            cross-attention; None selects self-attention.
        num_heads (int): Number of attention heads.
        bias (bool): Whether attention projections use bias. Default: False.
        qkv_norm_mode (str): Normalization mode for q/k/v; only "per_head"
            is supported. Default: "per_head".
        x_format (str): Declared input tensor format. Default: "BTHWD".

    Input shape:
        - x: (T, H, W, B, D) video features
        - context (optional): (M, B, D) context features for cross-attention
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        bias: bool = False,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
        weight_args={},
        operations=None,
    ) -> None:
        super().__init__()
        self.x_format = x_format

        # RMSNorm on q and k, identity on v ("RRI"); sequence-first layout.
        self.attn = Attention(
            x_dim,
            context_dim,
            num_heads,
            x_dim // num_heads,
            qkv_bias=bias,
            qkv_norm="RRI",
            out_bias=bias,
            qkv_norm_mode=qkv_norm_mode,
            qkv_format="sbhd",
            weight_args=weight_args,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Flatten the video grid, attend, and restore the grid.

        Args:
            x (Tensor): Video features of shape (T, H, W, B, D).
            context (Optional[Tensor]): Context of shape (M, B, D) for
                cross-attention; None for self-attention.
            crossattn_mask (Optional[Tensor]): Optional cross-attention mask.
            rope_emb_L_1_1_D (Optional[Tensor]): Rotary embedding of shape
                (L, 1, 1, D) with L == T*H*W (self-attention only).

        Returns:
            Tensor: Output with the same (T, H, W, B, D) shape as the input.
        """
        T, H, W, B, D = x.shape
        tokens_THW_B_D = x.reshape(T * H * W, B, D)
        tokens_THW_B_D = self.attn(
            tokens_THW_B_D,
            context,
            crossattn_mask,
            rope_emb=rope_emb_L_1_1_D,
        )
        return tokens_THW_B_D.reshape(T, H, W, B, D)
577
+
578
+
579
def adaln_norm_state(norm_state, x, scale, shift):
    """Normalize `x` with `norm_state`, then apply x * (1 + scale) + shift.

    Unlike `modulate`, `scale` and `shift` are expected to already be
    broadcast-compatible with the normalized tensor.
    """
    return norm_state(x) * (1 + scale) + shift
582
+
583
+
584
class DITBuildingBlock(nn.Module):
    """
    A building block for the DiT (Diffusion Transformer) architecture that supports different types of
    attention and MLP operations with adaptive layer normalization.

    Parameters:
        block_type (str): Type of block - one of:
            - "cross_attn"/"ca": Cross-attention
            - "full_attn"/"fa": Full self-attention
            - "mlp"/"ff": MLP/feedforward block
        x_dim (int): Dimension of input features
        context_dim (Optional[int]): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads
        mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
        bias (bool): Whether to use bias in layers. Default: False
        mlp_dropout (float): Dropout rate for MLP. Default: 0.0
        qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
        x_format (str): Input tensor format. Default: "BTHWD"
        use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
        adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
    """

    def __init__(
        self,
        block_type: str,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        mlp_ratio: float = 4.0,
        bias: bool = False,
        mlp_dropout: float = 0.0,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args={},
        operations=None
    ) -> None:
        # Normalize aliases ("CA" == "ca", etc.) before dispatching below.
        block_type = block_type.lower()

        super().__init__()
        self.x_format = x_format
        # Select the inner operation: cross-attention, self-attention, or MLP.
        if block_type in ["cross_attn", "ca"]:
            self.block = VideoAttn(
                x_dim,
                context_dim,
                num_heads,
                bias=bias,
                qkv_norm_mode=qkv_norm_mode,
                x_format=self.x_format,
                weight_args=weight_args,
                operations=operations,
            )
        elif block_type in ["full_attn", "fa"]:
            # context_dim=None makes VideoAttn run in self-attention mode.
            self.block = VideoAttn(
                x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
            )
        elif block_type in ["mlp", "ff"]:
            self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
        else:
            raise ValueError(f"Unknown block type: {block_type}")

        self.block_type = block_type
        self.use_adaln_lora = use_adaln_lora

        # Pre-norm (no affine params: shift/scale come from adaLN_modulation below).
        self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
        # Three modulation chunks: shift, scale, and residual gate.
        self.n_adaln_chunks = 3
        if use_adaln_lora:
            # Low-rank bottleneck variant of the modulation MLP.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
            )
        else:
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for dynamically configured blocks with adaptive normalization.

        Args:
            x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
            emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
            crossattn_emb (Tensor): Tensor for cross-attention blocks.
            crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
            rope_emb_L_1_1_D (Optional[Tensor]):
                Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.
            adaln_lora_B_3D (Optional[Tensor]): AdaLN-LoRA correction added to the
                modulation output; required when use_adaln_lora is True.

        Returns:
            Tensor: The output tensor after processing through the configured block and adaptive normalization.
        """
        # Compute shift/scale/gate; the LoRA branch adds a low-rank correction.
        if self.use_adaln_lora:
            shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
                self.n_adaln_chunks, dim=1
            )
        else:
            shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)

        # Insert three leading length-1 axes so (B, D) broadcasts over (T, H, W, B, D).
        shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
            shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
            scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
            gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
        )

        # Pre-norm + modulation -> sub-block -> gated residual connection.
        if self.block_type in ["mlp", "ff"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
            )
        elif self.block_type in ["full_attn", "fa"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
                context=None,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
            )
        elif self.block_type in ["cross_attn", "ca"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
                context=crossattn_emb,
                crossattn_mask=crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
            )
        else:
            raise ValueError(f"Unknown block type: {self.block_type}")

        return x
717
+
718
+
719
class GeneralDITTransformerBlock(nn.Module):
    """
    One complete DiT transformer layer built from a configurable sequence of
    DITBuildingBlocks.

    Parameters:
        x_dim (int): Dimension of input features.
        context_dim (int): Dimension of context features for cross-attention blocks.
        num_heads (int): Number of attention heads.
        block_config (str): Dash-separated block sequence; each element is one of
            "ca"/"cross_attn", "fa"/"full_attn", or "mlp"/"ff". For example
            "ca-fa-mlp" runs cross-attention, self-attention, then an MLP.
        mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0.
        x_format (str): Input tensor format. Default: "BTHWD".
        use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False.
        adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256.
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        block_config: str,
        mlp_ratio: float = 4.0,
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args={},
        operations=None
    ):
        super().__init__()
        self.x_format = x_format
        # One sub-block per dash-separated entry, in the order listed.
        self.blocks = nn.ModuleList(
            [
                DITBuildingBlock(
                    block_name,
                    x_dim,
                    context_dim,
                    num_heads,
                    mlp_ratio,
                    x_format=self.x_format,
                    use_adaln_lora=use_adaln_lora,
                    adaln_lora_dim=adaln_lora_dim,
                    weight_args=weight_args,
                    operations=operations,
                )
                for block_name in block_config.split("-")
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Thread `x` through each configured sub-block in order."""
        for sub_block in self.blocks:
            x = sub_block(
                x,
                emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
            )
        return x
ComfyUI/comfy/ldm/cosmos/model.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
18
+ """
19
+
20
+ from typing import Optional, Tuple
21
+
22
+ import torch
23
+ from einops import rearrange
24
+ from torch import nn
25
+ from torchvision import transforms
26
+
27
+ from enum import Enum
28
+ import logging
29
+
30
+ from .blocks import (
31
+ FinalLayer,
32
+ GeneralDITTransformerBlock,
33
+ PatchEmbed,
34
+ TimestepEmbedding,
35
+ Timesteps,
36
+ )
37
+
38
+ from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
39
+
40
+
41
class DataType(Enum):
    # Tags which modality a batch holds; GeneralDIT.forward_before_blocks asserts
    # that it receives an instance of this enum.
    IMAGE = "image"
    VIDEO = "video"
44
+
45
+
46
class GeneralDIT(nn.Module):
    """
    A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.

    Args:
        max_img_h (int): Maximum height of the input images.
        max_img_w (int): Maximum width of the input images.
        max_frames (int): Maximum number of frames in the video sequence.
        in_channels (int): Number of input channels (e.g., RGB channels for color images).
        out_channels (int): Number of output channels.
        patch_spatial (tuple): Spatial resolution of patches for input processing.
        patch_temporal (int): Temporal resolution of patches for input processing.
        concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
        block_config (str): Configuration of the transformer block. See Notes for supported block types.
        model_channels (int): Base number of channels used throughout the model.
        num_blocks (int): Number of transformer blocks.
        num_heads (int): Number of heads in the multi-head attention layers.
        mlp_ratio (float): Expansion ratio for MLP blocks.
        block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
        crossattn_emb_channels (int): Number of embedding channels for cross-attention.
        use_cross_attn_mask (bool): Whether to use mask in cross-attention.
        pos_emb_cls (str): Type of positional embeddings.
        pos_emb_learnable (bool): Whether positional embeddings are learnable.
        pos_emb_interpolation (str): Method for interpolating positional embeddings.
        affline_emb_norm (bool): Whether to normalize affine embeddings.
        use_adaln_lora (bool): Whether to use AdaLN-LoRA.
        adaln_lora_dim (int): Dimension for AdaLN-LoRA.
        rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
        rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
        rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
        extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
        extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
        extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
        extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
        extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.

    Notes:
        Supported block types in block_config:
        * cross_attn, ca: Cross attention
        * full_attn: Full attention on all flattened tokens
        * mlp, ff: Feed forward block
    """

    def __init__(
        self,
        max_img_h: int,
        max_img_w: int,
        max_frames: int,
        in_channels: int,
        out_channels: int,
        patch_spatial: tuple,
        patch_temporal: int,
        concat_padding_mask: bool = True,
        # attention settings
        block_config: str = "FA-CA-MLP",
        model_channels: int = 768,
        num_blocks: int = 10,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        block_x_format: str = "BTHWD",
        # cross attention settings
        crossattn_emb_channels: int = 1024,
        use_cross_attn_mask: bool = False,
        # positional embedding settings
        pos_emb_cls: str = "sincos",
        pos_emb_learnable: bool = False,
        pos_emb_interpolation: str = "crop",
        affline_emb_norm: bool = False,  # whether or not to normalize the affine embedding
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        rope_h_extrapolation_ratio: float = 1.0,
        rope_w_extrapolation_ratio: float = 1.0,
        rope_t_extrapolation_ratio: float = 1.0,
        extra_per_block_abs_pos_emb: bool = False,
        extra_per_block_abs_pos_emb_type: str = "sincos",
        extra_h_extrapolation_ratio: float = 1.0,
        extra_w_extrapolation_ratio: float = 1.0,
        extra_t_extrapolation_ratio: float = 1.0,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.model_channels = model_channels
        self.use_cross_attn_mask = use_cross_attn_mask
        self.concat_padding_mask = concat_padding_mask
        # positional embedding settings
        self.pos_emb_cls = pos_emb_cls
        self.pos_emb_learnable = pos_emb_learnable
        self.pos_emb_interpolation = pos_emb_interpolation
        self.affline_emb_norm = affline_emb_norm
        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
        self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
        self.dtype = dtype
        weight_args = {"device": device, "dtype": dtype}

        # One extra channel carries the (broadcast) padding mask when enabled.
        in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.x_embedder = PatchEmbed(
            spatial_patch_size=patch_spatial,
            temporal_patch_size=patch_temporal,
            in_channels=in_channels,
            out_channels=model_channels,
            bias=False,
            weight_args=weight_args,
            operations=operations,
        )

        self.build_pos_embed(device=device, dtype=dtype)
        self.block_x_format = block_x_format
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        # t_embedder[0] maps scalar timesteps to sinusoid features,
        # t_embedder[1] projects them to the model width (plus AdaLN-LoRA output).
        self.t_embedder = nn.ModuleList(
            [Timesteps(model_channels),
             TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
        )

        # NOTE: keyed by the string "block{idx}", not by integer index.
        self.blocks = nn.ModuleDict()

        for idx in range(num_blocks):
            self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
                x_dim=model_channels,
                context_dim=crossattn_emb_channels,
                num_heads=num_heads,
                block_config=block_config,
                mlp_ratio=mlp_ratio,
                x_format=self.block_x_format,
                use_adaln_lora=use_adaln_lora,
                adaln_lora_dim=adaln_lora_dim,
                weight_args=weight_args,
                operations=operations,
            )

        if self.affline_emb_norm:
            logging.debug("Building affine embedding normalization layer")
            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
        else:
            self.affline_norm = nn.Identity()

        self.final_layer = FinalLayer(
            hidden_size=self.model_channels,
            spatial_patch_size=self.patch_spatial,
            temporal_patch_size=self.patch_temporal,
            out_channels=self.out_channels,
            use_adaln_lora=self.use_adaln_lora,
            adaln_lora_dim=self.adaln_lora_dim,
            weight_args=weight_args,
            operations=operations,
        )

    def build_pos_embed(self, device=None, dtype=None):
        """Instantiate ``self.pos_embedder`` (and optionally ``self.extra_pos_embedder``).

        Only 'rope3d' is supported for the main positional embedding; the extra
        per-block absolute embedding, when enabled, must be 'learnable'.
        """
        if self.pos_emb_cls == "rope3d":
            cls_type = VideoRopePosition3DEmb
        else:
            raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")

        logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
        kwargs = dict(
            model_channels=self.model_channels,
            len_h=self.max_img_h // self.patch_spatial,
            len_w=self.max_img_w // self.patch_spatial,
            len_t=self.max_frames // self.patch_temporal,
            is_learnable=self.pos_emb_learnable,
            interpolation=self.pos_emb_interpolation,
            head_dim=self.model_channels // self.num_heads,
            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
            device=device,
        )
        self.pos_embedder = cls_type(
            **kwargs,
        )

        if self.extra_per_block_abs_pos_emb:
            assert self.extra_per_block_abs_pos_emb_type in [
                "learnable",
            ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
            # Reuse the kwargs above but swap in the "extra" extrapolation ratios.
            kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
            kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
            kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
            kwargs["device"] = device
            kwargs["dtype"] = dtype
            self.extra_pos_embedder = LearnablePosEmbAxis(
                **kwargs,
            )

    def prepare_embedded_sequence(
        self,
        x_B_C_T_H_W: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.

        Args:
            x_B_C_T_H_W (torch.Tensor): video
            fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
                If None, a default value (`self.base_fps`) will be used.
            padding_mask (Optional[torch.Tensor]): current it is not used

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - A tensor of shape (B, T, H, W, D) with the embedded sequence.
                - An optional positional embedding tensor, returned only if the positional embedding class
                  (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.

        Notes:
            - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
            - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
            - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
              the `self.pos_embedder` with the shape [T, H, W].
            - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
              `self.pos_embedder` with the fps tensor.
            - Otherwise, the positional embeddings are generated without considering fps.
        """
        if self.concat_padding_mask:
            if padding_mask is not None:
                # Resize the mask to the latent spatial resolution (nearest keeps it binary).
                padding_mask = transforms.functional.resize(
                    padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
                )
            else:
                # No mask supplied: concatenate an all-zero (nothing padded) channel.
                padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)

            x_B_C_T_H_W = torch.cat(
                [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
            )
        x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

        if self.extra_per_block_abs_pos_emb:
            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
        else:
            extra_pos_emb = None

        if "rope" in self.pos_emb_cls.lower():
            # RoPE is applied inside attention, so return it separately instead of adding it.
            return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb

        if "fps_aware" in self.pos_emb_cls:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]
        else:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]

        return x_B_T_H_W_D, None, extra_pos_emb

    def decoder_head(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        origin_shape: Tuple[int, int, int, int, int],  # [B, C, T, H, W]
        crossattn_mask: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Project transformer tokens back to pixel/latent space and un-patchify.

        ``crossattn_emb``/``crossattn_mask`` are accepted for interface symmetry
        but unused by the final layer.
        """
        del crossattn_emb, crossattn_mask
        B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
        x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
        x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
        # This is to ensure x_BT_HW_D has the correct shape because
        # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
        x_BT_HW_D = x_BT_HW_D.view(
            B * T_before_patchify // self.patch_temporal,
            H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
            -1,
        )
        x_B_D_T_H_W = rearrange(
            x_BT_HW_D,
            "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
            p1=self.patch_spatial,
            p2=self.patch_spatial,
            H=H_before_patchify // self.patch_spatial,
            W=W_before_patchify // self.patch_spatial,
            t=self.patch_temporal,
            B=B,
        )
        return x_B_D_T_H_W

    def forward_before_blocks(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            crossattn_emb: (B, N, D) tensor of cross-attention embeddings
            crossattn_mask: (B, N) tensor of cross-attention masks
        """
        del kwargs
        assert isinstance(
            data_type, DataType
        ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
        original_shape = x.shape
        x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
            x,
            fps=fps,
            padding_mask=padding_mask,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
        )
        # logging affline scale information
        affline_scale_log_info = {}

        timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
        affline_emb_B_D = timesteps_B_D
        affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()

        if scalar_feature is not None:
            raise NotImplementedError("Scalar feature is not implemented yet.")

        affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
        affline_emb_B_D = self.affline_norm(affline_emb_B_D)

        if self.use_cross_attn_mask:
            if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
                # Convert a {0, 1} mask to additive form: 0 -> -max (masked), 1 -> 0 (kept).
                crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
            crossattn_mask = crossattn_mask[:, None, None, :]  # .to(dtype=torch.bool)  # [B, 1, 1, length]
        else:
            crossattn_mask = None

        if self.blocks["block0"].x_format == "THWBD":
            x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
                extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
                    extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
                )
            crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")

            # NOTE(review): tensor truthiness raises for multi-element tensors;
            # this branch is only reachable today when crossattn_mask is None
            # (use_cross_attn_mask defaults to False) — confirm before enabling masks here.
            if crossattn_mask:
                crossattn_mask = rearrange(crossattn_mask, "B M -> M B")

        elif self.blocks["block0"].x_format == "BTHWD":
            x = x_B_T_H_W_D
        else:
            # FIX: self.blocks is a ModuleDict keyed by "block0"; integer indexing
            # raised KeyError while formatting this message.
            raise ValueError(f"Unknown x_format {self.blocks['block0'].x_format}")
        output = {
            "x": x,
            "affline_emb_B_D": affline_emb_B_D,
            "crossattn_emb": crossattn_emb,
            "crossattn_mask": crossattn_mask,
            "rope_emb_L_1_1_D": rope_emb_L_1_1_D,
            "adaln_lora_B_3D": adaln_lora_B_3D,
            "original_shape": original_shape,
            "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
        }
        return output

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        # crossattn_emb: torch.Tensor,
        # crossattn_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        condition_video_augment_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            crossattn_emb: (B, N, D) tensor of cross-attention embeddings
            crossattn_mask: (B, N) tensor of cross-attention masks
            condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to
                augment condition input, the lvg model will condition on the condition_video_augment_sigma value;
                we need forward_before_blocks pass to the forward_before_blocks function.
        """

        crossattn_emb = context
        crossattn_mask = attention_mask

        inputs = self.forward_before_blocks(
            x=x,
            timesteps=timesteps,
            crossattn_emb=crossattn_emb,
            crossattn_mask=crossattn_mask,
            fps=fps,
            image_size=image_size,
            padding_mask=padding_mask,
            scalar_feature=scalar_feature,
            data_type=data_type,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
            condition_video_augment_sigma=condition_video_augment_sigma,
            **kwargs,
        )
        x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
            inputs["x"],
            inputs["affline_emb_B_D"],
            inputs["crossattn_emb"],
            inputs["crossattn_mask"],
            inputs["rope_emb_L_1_1_D"],
            inputs["adaln_lora_B_3D"],
            inputs["original_shape"],
        )
        # FIX: the extra positional embedding is None when extra_per_block_abs_pos_emb
        # is False; calling .to() unconditionally raised AttributeError on None.
        extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"]
        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.to(x.dtype)
        del inputs

        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            assert (
                x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
            ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"

        for _, block in self.blocks.items():
            # FIX: error message previously indexed self.blocks[0] (KeyError on a
            # ModuleDict keyed by "block0").
            assert (
                self.blocks["block0"].x_format == block.x_format
            ), f"First block has x_format {self.blocks['block0'].x_format}, got {block.x_format}"

            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
                # The absolute embedding is re-added before every block.
                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
            x = block(
                x,
                affline_emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
            )

        x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

        x_B_D_T_H_W = self.decoder_head(
            x_B_T_H_W_D=x_B_T_H_W_D,
            emb_B_D=affline_emb_B_D,
            crossattn_emb=None,
            origin_shape=original_shape,
            crossattn_mask=None,
            adaln_lora_B_3D=adaln_lora_B_3D,
        )

        return x_B_D_T_H_W
ComfyUI/comfy/ldm/cosmos/position_embedding.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import List, Optional
17
+
18
+ import torch
19
+ from einops import rearrange, repeat
20
+ from torch import nn
21
+ import math
22
+
23
+
24
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
    """
    Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.

    Args:
        x (torch.Tensor): The input tensor to normalize.
        dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
        eps (float, optional): A small constant to ensure numerical stability during division.

    Returns:
        torch.Tensor: The normalized tensor.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    # Norm is computed in float32 for stability regardless of the input dtype.
    raw_norm = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # Scale so that the *average* square norm per element is adjusted, then add eps.
    scale = math.sqrt(raw_norm.numel() / x.numel())
    divisor = eps + raw_norm * scale
    return x / divisor.to(x.dtype)
41
+
42
+
43
class VideoPositionEmb(nn.Module):
    """Base class for video positional embeddings.

    Subclasses implement ``generate_embeddings``; ``forward`` only extracts the
    input's (B, T, H, W, C) shape and delegates.
    """

    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
        """
        It delegates the embedding generation to generate_embeddings function.
        """
        # FIX: the original signature read ``fps=Optional[torch.Tensor]``, which made
        # the *typing construct* the default value; any call without an explicit fps
        # then crashed downstream (e.g. on ``fps.min()``). The intended default is None.
        B_T_H_W_C = x_B_T_H_W_C.shape
        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)

        return embeddings

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None):
        # Abstract hook; ``dtype`` added to match how ``forward`` invokes it.
        raise NotImplementedError
55
+
56
+
57
class VideoRopePosition3DEmb(VideoPositionEmb):
    # 3D rotary positional embedding: the head dimension is partitioned into
    # height, width and temporal sub-bands, each with its own NTK-scaled
    # frequency table. Output is a (T*H*W, D/2, 2, 2) stack of 2x2 rotation
    # matrices consumed by the attention RoPE application.
    def __init__(
        self,
        *,  # enforce keyword arguments
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        enable_fps_modulation: bool = True,
        device=None,
        **kwargs,  # used for compatibility with other positional embeddings; unused in this class
    ):
        del kwargs
        super().__init__()
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w
        self.enable_fps_modulation = enable_fps_modulation

        # Split the head dim into h/w/t sub-bands: h and w each get
        # (head_dim // 6) * 2 channels, t takes the remainder.
        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        # Non-persistent buffers: recomputed per checkpoint, never serialized.
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        # NTK-aware extrapolation: raise the extrapolation ratio to
        # dim/(dim - 2) to stretch the RoPE base frequency per axis.
        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
        device=None,
        dtype=None,
    ):
        """
        Generate embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.

        Returns:
            Not specified in the original code snippet.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        # Per-axis RoPE base, stretched by the NTK factor.
        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))

        B, T, H, W, _ = B_T_H_W_C
        seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
        # fps is "uniform" when absent, scalar, or identical across the batch.
        uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        # Outer products give (positions x frequencies) phase angles per axis.
        half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
        half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)

        # apply sequence scaling in temporal dimension
        if fps is None or self.enable_fps_modulation is False:  # image case
            half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
        else:
            half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)

        # Pack each phase angle as a flattened 2x2 rotation matrix
        # [cos, -sin, sin, cos] along the last axis.
        half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
        half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
        half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)

        # Broadcast each axis over the full (T, H, W) grid and concatenate
        # along the frequency dimension: order is t, then h, then w.
        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
                repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
                repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
            ]
            , dim=-2,
        )

        return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
164
+
165
+
166
class LearnablePosEmbAxis(VideoPositionEmb):
    """Learned, factorized absolute positional embedding.

    Keeps one learnable table per axis (t, h, w) and sums their broadcasts over
    the (B, T, H, W) grid, then L2-normalizes along the channel dimension.
    """

    def __init__(
        self,
        *,  # enforce keyword arguments
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        device=None,
        dtype=None,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): we currently only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
        """
        del kwargs  # unused
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        # torch.empty is intentional: the values are loaded from a checkpoint.
        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
        # FIX: the original signature read ``fps=Optional[torch.Tensor]`` — a typo
        # that made the typing construct the default *value*; the intended default
        # is None (fps is unused here anyway).
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            # "crop" interpolation: take the leading T/H/W rows of each table.
            emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
            emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
            emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
            emb = (
                repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)
ComfyUI/comfy/ldm/cosmos/predict2.py ADDED
@@ -0,0 +1,864 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original code from: https://github.com/nvidia-cosmos/cosmos-predict2
2
+
3
+ import torch
4
+ from torch import nn
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ import logging
8
+ from typing import Callable, Optional, Tuple
9
+ import math
10
+
11
+ from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
12
+ from torchvision import transforms
13
+
14
+ from comfy.ldm.modules.attention import optimized_attention
15
+
16
+ def apply_rotary_pos_emb(
17
+ t: torch.Tensor,
18
+ freqs: torch.Tensor,
19
+ ) -> torch.Tensor:
20
+ t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
21
+ t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
22
+ t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
23
+ return t_out
24
+
25
+
26
+ # ---------------------- Feed Forward Network -----------------------
27
+ class GPT2FeedForward(nn.Module):
28
+ def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
29
+ super().__init__()
30
+ self.activation = nn.GELU()
31
+ self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
32
+ self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
33
+
34
+ self._layer_id = None
35
+ self._dim = d_model
36
+ self._hidden_dim = d_ff
37
+
38
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
39
+ x = self.layer1(x)
40
+
41
+ x = self.activation(x)
42
+ x = self.layer2(x)
43
+ return x
44
+
45
+
46
+ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
47
+ """Computes multi-head attention using PyTorch's native implementation.
48
+
49
+ This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
50
+ It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product
51
+ attention, and rearranges the output back to the original format.
52
+
53
+ The input tensor names use the following dimension conventions:
54
+
55
+ - B: batch size
56
+ - S: sequence length
57
+ - H: number of attention heads
58
+ - D: head dimension
59
+
60
+ Args:
61
+ q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
62
+ k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
63
+ v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)
64
+
65
+ Returns:
66
+ Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
67
+ """
68
+ in_q_shape = q_B_S_H_D.shape
69
+ in_k_shape = k_B_S_H_D.shape
70
+ q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
71
+ k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
72
+ v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
73
+ return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
74
+
75
+
76
+ class Attention(nn.Module):
77
+ """
78
+ A flexible attention module supporting both self-attention and cross-attention mechanisms.
79
+
80
+ This module implements a multi-head attention layer that can operate in either self-attention
81
+ or cross-attention mode. The mode is determined by whether a context dimension is provided.
82
+ The implementation uses scaled dot-product attention and supports optional bias terms and
83
+ dropout regularization.
84
+
85
+ Args:
86
+ query_dim (int): The dimensionality of the query vectors.
87
+ context_dim (int, optional): The dimensionality of the context (key/value) vectors.
88
+ If None, the module operates in self-attention mode using query_dim. Default: None
89
+ n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
90
+ head_dim (int, optional): The dimension of each attention head. Default: 64
91
+ dropout (float, optional): Dropout probability applied to the output. Default: 0.0
92
+ qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd"
93
+ backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine"
94
+
95
+ Examples:
96
+ >>> # Self-attention with 512 dimensions and 8 heads
97
+ >>> self_attn = Attention(query_dim=512)
98
+ >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim)
99
+ >>> out = self_attn(x) # (32, 16, 512)
100
+
101
+ >>> # Cross-attention
102
+ >>> cross_attn = Attention(query_dim=512, context_dim=256)
103
+ >>> query = torch.randn(32, 16, 512)
104
+ >>> context = torch.randn(32, 8, 256)
105
+ >>> out = cross_attn(query, context) # (32, 16, 512)
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ query_dim: int,
111
+ context_dim: Optional[int] = None,
112
+ n_heads: int = 8,
113
+ head_dim: int = 64,
114
+ dropout: float = 0.0,
115
+ device=None,
116
+ dtype=None,
117
+ operations=None,
118
+ ) -> None:
119
+ super().__init__()
120
+ logging.debug(
121
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
122
+ f"{n_heads} heads with a dimension of {head_dim}."
123
+ )
124
+ self.is_selfattn = context_dim is None # self attention
125
+
126
+ context_dim = query_dim if context_dim is None else context_dim
127
+ inner_dim = head_dim * n_heads
128
+
129
+ self.n_heads = n_heads
130
+ self.head_dim = head_dim
131
+ self.query_dim = query_dim
132
+ self.context_dim = context_dim
133
+
134
+ self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
135
+ self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
136
+
137
+ self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
138
+ self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
139
+
140
+ self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
141
+ self.v_norm = nn.Identity()
142
+
143
+ self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
144
+ self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
145
+
146
+ self.attn_op = torch_attention_op
147
+
148
+ self._query_dim = query_dim
149
+ self._context_dim = context_dim
150
+ self._inner_dim = inner_dim
151
+
152
+ def compute_qkv(
153
+ self,
154
+ x: torch.Tensor,
155
+ context: Optional[torch.Tensor] = None,
156
+ rope_emb: Optional[torch.Tensor] = None,
157
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
158
+ q = self.q_proj(x)
159
+ context = x if context is None else context
160
+ k = self.k_proj(context)
161
+ v = self.v_proj(context)
162
+ q, k, v = map(
163
+ lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
164
+ (q, k, v),
165
+ )
166
+
167
+ def apply_norm_and_rotary_pos_emb(
168
+ q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
169
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
170
+ q = self.q_norm(q)
171
+ k = self.k_norm(k)
172
+ v = self.v_norm(v)
173
+ if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
174
+ q = apply_rotary_pos_emb(q, rope_emb)
175
+ k = apply_rotary_pos_emb(k, rope_emb)
176
+ return q, k, v
177
+
178
+ q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
179
+
180
+ return q, k, v
181
+
182
+ def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
183
+ result = self.attn_op(q, k, v) # [B, S, H, D]
184
+ return self.output_dropout(self.output_proj(result))
185
+
186
+ def forward(
187
+ self,
188
+ x: torch.Tensor,
189
+ context: Optional[torch.Tensor] = None,
190
+ rope_emb: Optional[torch.Tensor] = None,
191
+ ) -> torch.Tensor:
192
+ """
193
+ Args:
194
+ x (Tensor): The query tensor of shape [B, Mq, K]
195
+ context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
196
+ """
197
+ q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
198
+ return self.compute_attention(q, k, v)
199
+
200
+
201
+ class Timesteps(nn.Module):
202
+ def __init__(self, num_channels: int):
203
+ super().__init__()
204
+ self.num_channels = num_channels
205
+
206
+ def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
207
+ assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
208
+ timesteps = timesteps_B_T.flatten().float()
209
+ half_dim = self.num_channels // 2
210
+ exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
211
+ exponent = exponent / (half_dim - 0.0)
212
+
213
+ emb = torch.exp(exponent)
214
+ emb = timesteps[:, None].float() * emb[None, :]
215
+
216
+ sin_emb = torch.sin(emb)
217
+ cos_emb = torch.cos(emb)
218
+ emb = torch.cat([cos_emb, sin_emb], dim=-1)
219
+
220
+ return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
221
+
222
+
223
+ class TimestepEmbedding(nn.Module):
224
+ def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
225
+ super().__init__()
226
+ logging.debug(
227
+ f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
228
+ )
229
+ self.in_dim = in_features
230
+ self.out_dim = out_features
231
+ self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
232
+ self.activation = nn.SiLU()
233
+ self.use_adaln_lora = use_adaln_lora
234
+ if use_adaln_lora:
235
+ self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype)
236
+ else:
237
+ self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype)
238
+
239
+ def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
240
+ emb = self.linear_1(sample)
241
+ emb = self.activation(emb)
242
+ emb = self.linear_2(emb)
243
+
244
+ if self.use_adaln_lora:
245
+ adaln_lora_B_T_3D = emb
246
+ emb_B_T_D = sample
247
+ else:
248
+ adaln_lora_B_T_3D = None
249
+ emb_B_T_D = emb
250
+
251
+ return emb_B_T_D, adaln_lora_B_T_3D
252
+
253
+
254
+ class PatchEmbed(nn.Module):
255
+ """
256
+ PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers,
257
+ depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions,
258
+ making it suitable for video and image processing tasks. It supports dividing the input into patches
259
+ and embedding each patch into a vector of size `out_channels`.
260
+
261
+ Parameters:
262
+ - spatial_patch_size (int): The size of each spatial patch.
263
+ - temporal_patch_size (int): The size of each temporal patch.
264
+ - in_channels (int): Number of input channels. Default: 3.
265
+ - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
266
+ - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True.
267
+ """
268
+
269
+ def __init__(
270
+ self,
271
+ spatial_patch_size: int,
272
+ temporal_patch_size: int,
273
+ in_channels: int = 3,
274
+ out_channels: int = 768,
275
+ device=None, dtype=None, operations=None
276
+ ):
277
+ super().__init__()
278
+ self.spatial_patch_size = spatial_patch_size
279
+ self.temporal_patch_size = temporal_patch_size
280
+
281
+ self.proj = nn.Sequential(
282
+ Rearrange(
283
+ "b c (t r) (h m) (w n) -> b t h w (c r m n)",
284
+ r=temporal_patch_size,
285
+ m=spatial_patch_size,
286
+ n=spatial_patch_size,
287
+ ),
288
+ operations.Linear(
289
+ in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
290
+ ),
291
+ )
292
+ self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
293
+
294
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
295
+ """
296
+ Forward pass of the PatchEmbed module.
297
+
298
+ Parameters:
299
+ - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
300
+ B is the batch size,
301
+ C is the number of channels,
302
+ T is the temporal dimension,
303
+ H is the height, and
304
+ W is the width of the input.
305
+
306
+ Returns:
307
+ - torch.Tensor: The embedded patches as a tensor, with shape b t h w c.
308
+ """
309
+ assert x.dim() == 5
310
+ _, _, T, H, W = x.shape
311
+ assert (
312
+ H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
313
+ ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
314
+ assert T % self.temporal_patch_size == 0
315
+ x = self.proj(x)
316
+ return x
317
+
318
+
319
+ class FinalLayer(nn.Module):
320
+ """
321
+ The final layer of video DiT.
322
+ """
323
+
324
+ def __init__(
325
+ self,
326
+ hidden_size: int,
327
+ spatial_patch_size: int,
328
+ temporal_patch_size: int,
329
+ out_channels: int,
330
+ use_adaln_lora: bool = False,
331
+ adaln_lora_dim: int = 256,
332
+ device=None, dtype=None, operations=None
333
+ ):
334
+ super().__init__()
335
+ self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
336
+ self.linear = operations.Linear(
337
+ hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
338
+ )
339
+ self.hidden_size = hidden_size
340
+ self.n_adaln_chunks = 2
341
+ self.use_adaln_lora = use_adaln_lora
342
+ self.adaln_lora_dim = adaln_lora_dim
343
+ if use_adaln_lora:
344
+ self.adaln_modulation = nn.Sequential(
345
+ nn.SiLU(),
346
+ operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
347
+ operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
348
+ )
349
+ else:
350
+ self.adaln_modulation = nn.Sequential(
351
+ nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
352
+ )
353
+
354
+ def forward(
355
+ self,
356
+ x_B_T_H_W_D: torch.Tensor,
357
+ emb_B_T_D: torch.Tensor,
358
+ adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
359
+ ):
360
+ if self.use_adaln_lora:
361
+ assert adaln_lora_B_T_3D is not None
362
+ shift_B_T_D, scale_B_T_D = (
363
+ self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
364
+ ).chunk(2, dim=-1)
365
+ else:
366
+ shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
367
+
368
+ shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange(
369
+ scale_B_T_D, "b t d -> b t 1 1 d"
370
+ )
371
+
372
+ def _fn(
373
+ _x_B_T_H_W_D: torch.Tensor,
374
+ _norm_layer: nn.Module,
375
+ _scale_B_T_1_1_D: torch.Tensor,
376
+ _shift_B_T_1_1_D: torch.Tensor,
377
+ ) -> torch.Tensor:
378
+ return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
379
+
380
+ x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D)
381
+ x_B_T_H_W_O = self.linear(x_B_T_H_W_D)
382
+ return x_B_T_H_W_O
383
+
384
+
385
+ class Block(nn.Module):
386
+ """
387
+ A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
388
+ Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.
389
+
390
+ Parameters:
391
+ x_dim (int): Dimension of input features
392
+ context_dim (int): Dimension of context features for cross-attention
393
+ num_heads (int): Number of attention heads
394
+ mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
395
+ use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
396
+ adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256
397
+
398
+ The block applies the following sequence:
399
+ 1. Self-attention with AdaLN modulation
400
+ 2. Cross-attention with AdaLN modulation
401
+ 3. MLP with AdaLN modulation
402
+
403
+ Each component uses skip connections and layer normalization.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ x_dim: int,
409
+ context_dim: int,
410
+ num_heads: int,
411
+ mlp_ratio: float = 4.0,
412
+ use_adaln_lora: bool = False,
413
+ adaln_lora_dim: int = 256,
414
+ device=None,
415
+ dtype=None,
416
+ operations=None,
417
+ ):
418
+ super().__init__()
419
+ self.x_dim = x_dim
420
+ self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
421
+ self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)
422
+
423
+ self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
424
+ self.cross_attn = Attention(
425
+ x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
426
+ )
427
+
428
+ self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
429
+ self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
430
+
431
+ self.use_adaln_lora = use_adaln_lora
432
+ if self.use_adaln_lora:
433
+ self.adaln_modulation_self_attn = nn.Sequential(
434
+ nn.SiLU(),
435
+ operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
436
+ operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
437
+ )
438
+ self.adaln_modulation_cross_attn = nn.Sequential(
439
+ nn.SiLU(),
440
+ operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
441
+ operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
442
+ )
443
+ self.adaln_modulation_mlp = nn.Sequential(
444
+ nn.SiLU(),
445
+ operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
446
+ operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
447
+ )
448
+ else:
449
+ self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
450
+ self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
451
+ self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
452
+
453
+ def forward(
454
+ self,
455
+ x_B_T_H_W_D: torch.Tensor,
456
+ emb_B_T_D: torch.Tensor,
457
+ crossattn_emb: torch.Tensor,
458
+ rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
459
+ adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
460
+ extra_per_block_pos_emb: Optional[torch.Tensor] = None,
461
+ ) -> torch.Tensor:
462
+ if extra_per_block_pos_emb is not None:
463
+ x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
464
+
465
+ if self.use_adaln_lora:
466
+ shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
467
+ self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
468
+ ).chunk(3, dim=-1)
469
+ shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
470
+ self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
471
+ ).chunk(3, dim=-1)
472
+ shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
473
+ self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
474
+ ).chunk(3, dim=-1)
475
+ else:
476
+ shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
477
+ emb_B_T_D
478
+ ).chunk(3, dim=-1)
479
+ shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
480
+ emb_B_T_D
481
+ ).chunk(3, dim=-1)
482
+ shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
483
+
484
+ # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
485
+ shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
486
+ scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
487
+ gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")
488
+
489
+ shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
490
+ scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
491
+ gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")
492
+
493
+ shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
494
+ scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
495
+ gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")
496
+
497
+ B, T, H, W, D = x_B_T_H_W_D.shape
498
+
499
+ def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
500
+ return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D
501
+
502
+ normalized_x_B_T_H_W_D = _fn(
503
+ x_B_T_H_W_D,
504
+ self.layer_norm_self_attn,
505
+ scale_self_attn_B_T_1_1_D,
506
+ shift_self_attn_B_T_1_1_D,
507
+ )
508
+ result_B_T_H_W_D = rearrange(
509
+ self.self_attn(
510
+ # normalized_x_B_T_HW_D,
511
+ rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
512
+ None,
513
+ rope_emb=rope_emb_L_1_1_D,
514
+ ),
515
+ "b (t h w) d -> b t h w d",
516
+ t=T,
517
+ h=H,
518
+ w=W,
519
+ )
520
+ x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
521
+
522
+ def _x_fn(
523
+ _x_B_T_H_W_D: torch.Tensor,
524
+ layer_norm_cross_attn: Callable,
525
+ _scale_cross_attn_B_T_1_1_D: torch.Tensor,
526
+ _shift_cross_attn_B_T_1_1_D: torch.Tensor,
527
+ ) -> torch.Tensor:
528
+ _normalized_x_B_T_H_W_D = _fn(
529
+ _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
530
+ )
531
+ _result_B_T_H_W_D = rearrange(
532
+ self.cross_attn(
533
+ rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
534
+ crossattn_emb,
535
+ rope_emb=rope_emb_L_1_1_D,
536
+ ),
537
+ "b (t h w) d -> b t h w d",
538
+ t=T,
539
+ h=H,
540
+ w=W,
541
+ )
542
+ return _result_B_T_H_W_D
543
+
544
+ result_B_T_H_W_D = _x_fn(
545
+ x_B_T_H_W_D,
546
+ self.layer_norm_cross_attn,
547
+ scale_cross_attn_B_T_1_1_D,
548
+ shift_cross_attn_B_T_1_1_D,
549
+ )
550
+ x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
551
+
552
+ normalized_x_B_T_H_W_D = _fn(
553
+ x_B_T_H_W_D,
554
+ self.layer_norm_mlp,
555
+ scale_mlp_B_T_1_1_D,
556
+ shift_mlp_B_T_1_1_D,
557
+ )
558
+ result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
559
+ x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
560
+ return x_B_T_H_W_D
561
+
562
+
563
+ class MiniTrainDIT(nn.Module):
564
+ """
565
+ A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
566
+ A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
567
+
568
+ Args:
569
+ max_img_h (int): Maximum height of the input images.
570
+ max_img_w (int): Maximum width of the input images.
571
+ max_frames (int): Maximum number of frames in the video sequence.
572
+ in_channels (int): Number of input channels (e.g., RGB channels for color images).
573
+ out_channels (int): Number of output channels.
574
+ patch_spatial (tuple): Spatial resolution of patches for input processing.
575
+ patch_temporal (int): Temporal resolution of patches for input processing.
576
+ concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
577
+ model_channels (int): Base number of channels used throughout the model.
578
+ num_blocks (int): Number of transformer blocks.
579
+ num_heads (int): Number of heads in the multi-head attention layers.
580
+ mlp_ratio (float): Expansion ratio for MLP blocks.
581
+ crossattn_emb_channels (int): Number of embedding channels for cross-attention.
582
+ pos_emb_cls (str): Type of positional embeddings.
583
+ pos_emb_learnable (bool): Whether positional embeddings are learnable.
584
+ pos_emb_interpolation (str): Method for interpolating positional embeddings.
585
+ min_fps (int): Minimum frames per second.
586
+ max_fps (int): Maximum frames per second.
587
+ use_adaln_lora (bool): Whether to use AdaLN-LoRA.
588
+ adaln_lora_dim (int): Dimension for AdaLN-LoRA.
589
+ rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
590
+ rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
591
+ rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
592
+ extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
593
+ extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
594
+ extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
595
+ extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
596
+ """
597
+
598
+ def __init__(
599
+ self,
600
+ max_img_h: int,
601
+ max_img_w: int,
602
+ max_frames: int,
603
+ in_channels: int,
604
+ out_channels: int,
605
+ patch_spatial: int, # tuple,
606
+ patch_temporal: int,
607
+ concat_padding_mask: bool = True,
608
+ # attention settings
609
+ model_channels: int = 768,
610
+ num_blocks: int = 10,
611
+ num_heads: int = 16,
612
+ mlp_ratio: float = 4.0,
613
+ # cross attention settings
614
+ crossattn_emb_channels: int = 1024,
615
+ # positional embedding settings
616
+ pos_emb_cls: str = "sincos",
617
+ pos_emb_learnable: bool = False,
618
+ pos_emb_interpolation: str = "crop",
619
+ min_fps: int = 1,
620
+ max_fps: int = 30,
621
+ use_adaln_lora: bool = False,
622
+ adaln_lora_dim: int = 256,
623
+ rope_h_extrapolation_ratio: float = 1.0,
624
+ rope_w_extrapolation_ratio: float = 1.0,
625
+ rope_t_extrapolation_ratio: float = 1.0,
626
+ extra_per_block_abs_pos_emb: bool = False,
627
+ extra_h_extrapolation_ratio: float = 1.0,
628
+ extra_w_extrapolation_ratio: float = 1.0,
629
+ extra_t_extrapolation_ratio: float = 1.0,
630
+ rope_enable_fps_modulation: bool = True,
631
+ image_model=None,
632
+ device=None,
633
+ dtype=None,
634
+ operations=None,
635
+ ) -> None:
636
+ super().__init__()
637
+ self.dtype = dtype
638
+ self.max_img_h = max_img_h
639
+ self.max_img_w = max_img_w
640
+ self.max_frames = max_frames
641
+ self.in_channels = in_channels
642
+ self.out_channels = out_channels
643
+ self.patch_spatial = patch_spatial
644
+ self.patch_temporal = patch_temporal
645
+ self.num_heads = num_heads
646
+ self.num_blocks = num_blocks
647
+ self.model_channels = model_channels
648
+ self.concat_padding_mask = concat_padding_mask
649
+ # positional embedding settings
650
+ self.pos_emb_cls = pos_emb_cls
651
+ self.pos_emb_learnable = pos_emb_learnable
652
+ self.pos_emb_interpolation = pos_emb_interpolation
653
+ self.min_fps = min_fps
654
+ self.max_fps = max_fps
655
+ self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
656
+ self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
657
+ self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
658
+ self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
659
+ self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
660
+ self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
661
+ self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
662
+ self.rope_enable_fps_modulation = rope_enable_fps_modulation
663
+
664
+ self.build_pos_embed(device=device, dtype=dtype)
665
+ self.use_adaln_lora = use_adaln_lora
666
+ self.adaln_lora_dim = adaln_lora_dim
667
+ self.t_embedder = nn.Sequential(
668
+ Timesteps(model_channels),
669
+ TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
670
+ )
671
+
672
+ in_channels = in_channels + 1 if concat_padding_mask else in_channels
673
+ self.x_embedder = PatchEmbed(
674
+ spatial_patch_size=patch_spatial,
675
+ temporal_patch_size=patch_temporal,
676
+ in_channels=in_channels,
677
+ out_channels=model_channels,
678
+ device=device, dtype=dtype, operations=operations,
679
+ )
680
+
681
+ self.blocks = nn.ModuleList(
682
+ [
683
+ Block(
684
+ x_dim=model_channels,
685
+ context_dim=crossattn_emb_channels,
686
+ num_heads=num_heads,
687
+ mlp_ratio=mlp_ratio,
688
+ use_adaln_lora=use_adaln_lora,
689
+ adaln_lora_dim=adaln_lora_dim,
690
+ device=device, dtype=dtype, operations=operations,
691
+ )
692
+ for _ in range(num_blocks)
693
+ ]
694
+ )
695
+
696
+ self.final_layer = FinalLayer(
697
+ hidden_size=self.model_channels,
698
+ spatial_patch_size=self.patch_spatial,
699
+ temporal_patch_size=self.patch_temporal,
700
+ out_channels=self.out_channels,
701
+ use_adaln_lora=self.use_adaln_lora,
702
+ adaln_lora_dim=self.adaln_lora_dim,
703
+ device=device, dtype=dtype, operations=operations,
704
+ )
705
+
706
+ self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)
707
+
708
+ def build_pos_embed(self, device=None, dtype=None) -> None:
709
+ if self.pos_emb_cls == "rope3d":
710
+ cls_type = VideoRopePosition3DEmb
711
+ else:
712
+ raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
713
+
714
+ logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
715
+ kwargs = dict(
716
+ model_channels=self.model_channels,
717
+ len_h=self.max_img_h // self.patch_spatial,
718
+ len_w=self.max_img_w // self.patch_spatial,
719
+ len_t=self.max_frames // self.patch_temporal,
720
+ max_fps=self.max_fps,
721
+ min_fps=self.min_fps,
722
+ is_learnable=self.pos_emb_learnable,
723
+ interpolation=self.pos_emb_interpolation,
724
+ head_dim=self.model_channels // self.num_heads,
725
+ h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
726
+ w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
727
+ t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
728
+ enable_fps_modulation=self.rope_enable_fps_modulation,
729
+ device=device,
730
+ )
731
+ self.pos_embedder = cls_type(
732
+ **kwargs, # type: ignore
733
+ )
734
+
735
+ if self.extra_per_block_abs_pos_emb:
736
+ kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
737
+ kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
738
+ kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
739
+ kwargs["device"] = device
740
+ kwargs["dtype"] = dtype
741
+ self.extra_pos_embedder = LearnablePosEmbAxis(
742
+ **kwargs, # type: ignore
743
+ )
744
+
745
+ def prepare_embedded_sequence(
746
+ self,
747
+ x_B_C_T_H_W: torch.Tensor,
748
+ fps: Optional[torch.Tensor] = None,
749
+ padding_mask: Optional[torch.Tensor] = None,
750
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
751
+ """
752
+ Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
753
+
754
+ Args:
755
+ x_B_C_T_H_W (torch.Tensor): video
756
+ fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
757
+ If None, a default value (`self.base_fps`) will be used.
758
+ padding_mask (Optional[torch.Tensor]): current it is not used
759
+
760
+ Returns:
761
+ Tuple[torch.Tensor, Optional[torch.Tensor]]:
762
+ - A tensor of shape (B, T, H, W, D) with the embedded sequence.
763
+ - An optional positional embedding tensor, returned only if the positional embedding class
764
+ (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
765
+
766
+ Notes:
767
+ - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
768
+ - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
769
+ - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
770
+ the `self.pos_embedder` with the shape [T, H, W].
771
+ - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
772
+ `self.pos_embedder` with the fps tensor.
773
+ - Otherwise, the positional embeddings are generated without considering fps.
774
+ """
775
+ if self.concat_padding_mask:
776
+ if padding_mask is None:
777
+ padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
778
+ else:
779
+ padding_mask = transforms.functional.resize(
780
+ padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
781
+ )
782
+ x_B_C_T_H_W = torch.cat(
783
+ [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
784
+ )
785
+ x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
786
+
787
+ if self.extra_per_block_abs_pos_emb:
788
+ extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
789
+ else:
790
+ extra_pos_emb = None
791
+
792
+ if "rope" in self.pos_emb_cls.lower():
793
+ return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
794
+ x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D]
795
+
796
+ return x_B_T_H_W_D, None, extra_pos_emb
797
+
798
+ def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
799
+ x_B_C_Tt_Hp_Wp = rearrange(
800
+ x_B_T_H_W_M,
801
+ "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
802
+ p1=self.patch_spatial,
803
+ p2=self.patch_spatial,
804
+ t=self.patch_temporal,
805
+ )
806
+ return x_B_C_Tt_Hp_Wp
807
+
808
+ def forward(
809
+ self,
810
+ x: torch.Tensor,
811
+ timesteps: torch.Tensor,
812
+ context: torch.Tensor,
813
+ fps: Optional[torch.Tensor] = None,
814
+ padding_mask: Optional[torch.Tensor] = None,
815
+ **kwargs,
816
+ ):
817
+ x_B_C_T_H_W = x
818
+ timesteps_B_T = timesteps
819
+ crossattn_emb = context
820
+ """
821
+ Args:
822
+ x: (B, C, T, H, W) tensor of spatial-temp inputs
823
+ timesteps: (B, ) tensor of timesteps
824
+ crossattn_emb: (B, N, D) tensor of cross-attention embeddings
825
+ """
826
+ x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
827
+ x_B_C_T_H_W,
828
+ fps=fps,
829
+ padding_mask=padding_mask,
830
+ )
831
+
832
+ if timesteps_B_T.ndim == 1:
833
+ timesteps_B_T = timesteps_B_T.unsqueeze(1)
834
+ t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
835
+ t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)
836
+
837
+ # for logging purpose
838
+ affline_scale_log_info = {}
839
+ affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
840
+ self.affline_scale_log_info = affline_scale_log_info
841
+ self.affline_emb = t_embedding_B_T_D
842
+ self.crossattn_emb = crossattn_emb
843
+
844
+ if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
845
+ assert (
846
+ x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
847
+ ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"
848
+
849
+ block_kwargs = {
850
+ "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
851
+ "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
852
+ "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
853
+ }
854
+ for block in self.blocks:
855
+ x_B_T_H_W_D = block(
856
+ x_B_T_H_W_D,
857
+ t_embedding_B_T_D,
858
+ crossattn_emb,
859
+ **block_kwargs,
860
+ )
861
+
862
+ x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
863
+ x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
864
+ return x_B_C_Tt_Hp_Wp
ComfyUI/comfy/ldm/cosmos/vae.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
16
+
17
+ import logging
18
+ import torch
19
+ from torch import nn
20
+ from enum import Enum
21
+ import math
22
+
23
+ from .cosmos_tokenizer.layers3d import (
24
+ EncoderFactorized,
25
+ DecoderFactorized,
26
+ CausalConv3d,
27
+ )
28
+
29
+
30
class IdentityDistribution(torch.nn.Module):
    """Pass-through "distribution" used for the plain autoencoder (AE) formulation.

    Returns the encoder output unchanged, with dummy zero statistics so its
    return signature matches GaussianDistribution.
    """

    def __init__(self):
        super().__init__()

    def forward(self, parameters):
        # No sampling: the latent is the encoder output itself.
        zero_mean = torch.tensor([0.0])
        zero_logvar = torch.tensor([0.0])
        return parameters, (zero_mean, zero_logvar)
36
+
37
+
38
class GaussianDistribution(torch.nn.Module):
    """Diagonal Gaussian posterior used for the VAE formulation.

    The encoder output is split channel-wise into mean and log-variance; the
    log-variance is clamped to [min_logvar, max_logvar] before sampling with
    the reparameterization trick.
    """

    def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
        super().__init__()
        self.min_logvar = min_logvar
        self.max_logvar = max_logvar

    def sample(self, mean, logvar):
        # Reparameterization: z = mu + sigma * eps, with eps ~ N(0, I).
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(mean)
        return mean + sigma * noise

    def forward(self, parameters):
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        clamped_logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
        return self.sample(mean, clamped_logvar), (mean, clamped_logvar)
52
+
53
+
54
class ContinuousFormulation(Enum):
    # Maps a formulation name ("VAE"/"AE") to the distribution module applied
    # to the encoder output; see CausalContinuousVideoTokenizer.
    VAE = GaussianDistribution
    AE = IdentityDistribution
57
+
58
+
59
class CausalContinuousVideoTokenizer(nn.Module):
    """Causal continuous video tokenizer (VAE/AE formulation) for 3D data.

    Encodes video into a normalized continuous latent and decodes it back.
    assumes inputs are (B, C, T, H, W) — TODO confirm against EncoderFactorized.
    """

    def __init__(
        self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
    ) -> None:
        super().__init__()
        self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
        self.latent_channels = latent_channels
        # Target std of the normalized latents (EDM-style convention).
        self.sigma_data = 0.5

        # encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
        self.encoder = EncoderFactorized(
            z_channels=z_factor * z_channels, **kwargs
        )
        # NOTE(review): this kwargs mutation happens AFTER the encoder is built,
        # so only the decoder sees channels_mult=[2, 4] — presumably intentional
        # for the 4x temporal-compression config; verify against checkpoints.
        if kwargs.get("temporal_compression", 4) == 4:
            kwargs["channels_mult"] = [2, 4]
        # decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
        self.decoder = DecoderFactorized(
            z_channels=z_channels, **kwargs
        )

        # 1x1x1 causal convs mapping encoder output <-> latent space.
        self.quant_conv = CausalConv3d(
            z_factor * z_channels,
            z_factor * latent_channels,
            kernel_size=1,
            padding=0,
        )
        self.post_quant_conv = CausalConv3d(
            latent_channels, z_channels, kernel_size=1, padding=0
        )

        # formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
        # Hard-wired to AE (identity) here; the VAE alternative is GaussianDistribution.
        self.distribution = IdentityDistribution()  # ContinuousFormulation[formulation_name].value()

        num_parameters = sum(param.numel() for param in self.parameters())
        logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
        logging.debug(
            f"z_channels={z_channels}, latent_channels={self.latent_channels}."
        )

        # Per-(channel, temporal-chunk) normalization statistics, expected to be
        # loaded from a checkpoint; latents are standardized over chunks of 16 frames.
        latent_temporal_chunk = 16
        self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
        self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))


    def encode(self, x):
        """Encode video x into a standardized latent scaled to sigma_data."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        z, posteriors = self.distribution(moments)
        latent_ch = z.shape[1]
        latent_t = z.shape[2]
        in_dtype = z.dtype
        mean = self.latent_mean.view(latent_ch, -1)
        std = self.latent_std.view(latent_ch, -1)

        # Tile the per-chunk stats along time to cover latent_t frames, then
        # broadcast over batch and spatial dims.
        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
        return ((z - mean) / std) * self.sigma_data

    def decode(self, z):
        """Invert encode(): de-standardize the latent and run the decoder."""
        in_dtype = z.dtype
        latent_ch = z.shape[1]
        latent_t = z.shape[2]
        mean = self.latent_mean.view(latent_ch, -1)
        std = self.latent_std.view(latent_ch, -1)

        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)

        z = z / self.sigma_data
        z = z * std + mean
        z = self.post_quant_conv(z)
        return self.decoder(z)
131
+
ComfyUI/comfy/ldm/flux/controlnet.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
2
+ #modified to support different types of flux controlnets
3
+
4
+ import torch
5
+ import math
6
+ from torch import Tensor, nn
7
+ from einops import rearrange, repeat
8
+
9
+ from .layers import (timestep_embedding)
10
+
11
+ from .model import Flux
12
+ import comfy.ldm.common_dit
13
+
14
class MistolineCondDownsamplBlock(nn.Module):
    """Hint-image encoder for the Mistoline flux controlnet.

    Maps a 3-channel condition image to 16 channels and downsamples it 8x
    (three stride-2 convolutions), with SiLU between convolutions.
    """

    def __init__(self, dtype=None, device=None, operations=None):
        super().__init__()

        def conv(c_in, c_out, kernel, **kw):
            # All convolutions share dtype/device and come from the patched op set.
            return operations.Conv2d(c_in, c_out, kernel, dtype=dtype, device=device, **kw)

        layers = [conv(3, 16, 3, padding=1)]
        layers += [nn.SiLU(), conv(16, 16, 1)]
        layers += [nn.SiLU(), conv(16, 16, 3, padding=1)]
        layers += [nn.SiLU(), conv(16, 16, 3, padding=1, stride=2)]
        layers += [nn.SiLU(), conv(16, 16, 3, padding=1)]
        layers += [nn.SiLU(), conv(16, 16, 3, padding=1, stride=2)]
        layers += [nn.SiLU(), conv(16, 16, 3, padding=1)]
        layers += [nn.SiLU(), conv(16, 16, 3, padding=1, stride=2)]
        layers += [nn.SiLU(), conv(16, 16, 1)]
        layers += [nn.SiLU(), conv(16, 16, 3, padding=1)]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode the hint image into a 16-channel, 8x-downsampled feature map."""
        return self.encoder(x)
41
+
42
class MistolineControlnetBlock(nn.Module):
    """Per-block output projection for the Mistoline controlnet: Linear + SiLU."""

    def __init__(self, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
        self.act = nn.SiLU()

    def forward(self, x):
        projected = self.linear(x)
        return self.act(projected)
50
+
51
+
52
class ControlNetFlux(Flux):
    """Flux-based controlnet supporting XLabs, union and Mistoline variants.

    Reuses the Flux trunk (without the final layer) and taps each double/single
    block's output through a small projection, producing per-layer residuals
    for the main model.
    """

    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        # Number of double/single blocks in the full-size main model the
        # residuals are broadcast to (the controlnet itself may be shallower).
        self.main_model_double = 19
        self.main_model_single = 38

        self.mistoline = mistoline
        # add ControlNet blocks
        if self.mistoline:
            control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
        else:
            control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)

        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(self.params.depth):
            self.controlnet_blocks.append(control_block())

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(self.params.depth_single_blocks):
            self.controlnet_single_blocks.append(control_block())

        # Union controlnets embed a control-type id that is prepended to the
        # text tokens.
        self.num_union_modes = num_union_modes
        self.controlnet_mode_embedder = None
        if self.num_union_modes > 0:
            self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)

        self.gradient_checkpointing = False
        self.latent_input = latent_input
        if control_latent_channels is None:
            control_latent_channels = self.in_channels
        else:
            control_latent_channels *= 2 * 2 #patch size

        self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        if not self.latent_input:
            # Pixel-space hints get encoded and 8x-downsampled to match the
            # patchified latent resolution.
            if self.mistoline:
                self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
            else:
                self.input_hint_block = nn.Sequential(
                    operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
                )

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control_type: Tensor = None,
    ) -> Tensor:
        """Run the trunk and collect per-block control residuals.

        Returns a dict with "input" (double-block residuals) and optionally
        "output" (single-block residuals), each tiled up to the main model's
        depth.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        if y is None:
            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
        else:
            y = y[:, :self.params.vec_in_dim]

        # running on sequences img
        img = self.img_in(img)

        # Inject the (already projected) condition directly into the image tokens.
        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Union mode: prepend one embedded control-type token per entry in
        # control_type, duplicating the first txt position id for it.
        if self.controlnet_mode_embedder is not None and len(control_type) > 0:
            control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
            txt = torch.cat([control_cond, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        controlnet_double = ()

        for i in range(len(self.double_blocks)):
            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
            controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)

        img = torch.cat((txt, img), 1)

        controlnet_single = ()

        for i in range(len(self.single_blocks)):
            img = self.single_blocks[i](img, vec=vec, pe=pe)
            # Only the image-token part of the concatenated sequence is tapped.
            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)

        # Tile the residuals to cover the (deeper) main model. Latent-input
        # controlnets repeat each residual consecutively; others repeat the
        # whole sequence.
        repeat = math.ceil(self.main_model_double / len(controlnet_double))
        if self.latent_input:
            out_input = ()
            for x in controlnet_double:
                out_input += (x,) * repeat
        else:
            out_input = (controlnet_double * repeat)

        out = {"input": out_input[:self.main_model_double]}
        if len(controlnet_single) > 0:
            repeat = math.ceil(self.main_model_single / len(controlnet_single))
            out_output = ()
            if self.latent_input:
                for x in controlnet_single:
                    out_output += (x,) * repeat
            else:
                out_output = (controlnet_single * repeat)
            out["output"] = out_output[:self.main_model_single]
        return out

    def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
        """Patchify x and the hint, build position ids, then run forward_orig."""
        patch_size = 2
        if self.latent_input:
            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
        elif self.mistoline:
            # Pixel hints come in [0, 1]; shift to [-1, 1] before encoding.
            hint = hint * 2.0 - 1.0
            hint = self.input_cond_block(hint)
        else:
            hint = hint * 2.0 - 1.0
            hint = self.input_hint_block(hint)

        hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        # Rounded-up number of patches per axis after padding.
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
ComfyUI/comfy/ldm/flux/layers.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from torch import Tensor, nn
6
+
7
+ from .math import attention, rope
8
+ import comfy.ops
9
+ import comfy.ldm.common_dit
10
+
11
+
12
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding.

    Each axis i of the id tensor's last dimension gets its own rotary
    embedding of size axes_dim[i]; the per-axis embeddings are concatenated
    along the rotary dimension and a broadcast head axis is inserted.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        emb = torch.cat(per_axis, dim=-3)
        return emb.unsqueeze(1)
27
+
28
+
29
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: scale applied to t before embedding.
    :return: an (N, dim) Tensor of positional embeddings.
    """
    scaled = time_factor * t
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/max_period.
    exponent = -math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half
    freqs = torch.exp(exponent)

    angles = scaled[:, None].float() * freqs[None]
    embedding = torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)
    if dim % 2:
        # Odd output dims get a single zero column of padding.
        embedding = torch.cat((embedding, torch.zeros_like(embedding[:, :1])), dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
49
+
50
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) projecting into hidden_dim."""

    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.silu = nn.SiLU()
        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
59
+
60
+
61
class RMSNorm(torch.nn.Module):
    # RMS normalization with a learned per-channel scale. The math is delegated
    # to comfy.ldm.common_dit.rms_norm with eps=1e-6.
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        # Allocated empty: weights are expected to be loaded from a checkpoint.
        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

    def forward(self, x: Tensor):
        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
68
+
69
+
70
class QKNorm(torch.nn.Module):
    """Applies independent RMS normalization to attention queries and keys."""

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
        # v is only used to choose the output dtype/device.
        normed_q = self.query_norm(q).to(v)
        normed_k = self.key_norm(k).to(v)
        return normed_q, normed_k
80
+
81
+
82
class SelfAttention(nn.Module):
    # Parameter container for a self-attention layer: fused qkv projection,
    # QK RMS-normalization, and output projection. Note there is no forward();
    # DoubleStreamBlock / SingleStreamBlock drive these submodules directly.
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
91
+
92
+
93
@dataclass
class ModulationOut:
    # One adaLN modulation triple, applied as:
    #   x -> gate * f((1 + scale) * norm(x) + shift)
    shift: Tensor
    scale: Tensor
    gate: Tensor
98
+
99
+
100
class Modulation(nn.Module):
    """Projects a conditioning vector to one (single) or two (double) sets of
    shift/scale/gate modulation parameters via a SiLU-gated linear layer."""

    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

    def forward(self, vec: Tensor) -> tuple:
        if vec.ndim == 2:
            # Add a sequence axis so the modulation broadcasts over tokens.
            vec = vec[:, None, :]
        chunks = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)

        first = ModulationOut(*chunks[:3])
        second = ModulationOut(*chunks[3:]) if self.is_double else None
        return first, second
116
+
117
+
118
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
    """Apply scale (and optional shift) modulation to a tensor.

    When modulation_dims is None the whole tensor is modulated out-of-place.
    Otherwise modulation_dims is an iterable of (start, end, index) triples and
    tensor[:, start:end] is scaled/shifted IN PLACE by m_mult[:, index] /
    m_add[:, index]; the (mutated) input tensor is returned.
    """
    if modulation_dims is None:
        if m_add is None:
            return tensor * m_mult
        # Fused tensor * m_mult + m_add.
        return torch.addcmul(m_add, tensor, m_mult)

    for start, end, idx in modulation_dims:
        tensor[:, start:end] *= m_mult[:, idx]
        if m_add is not None:
            tensor[:, start:end] += m_add[:, idx]
    return tensor
130
+
131
+
132
class DoubleStreamBlock(nn.Module):
    """MMDiT-style block with separate image and text streams that attend jointly.

    Each stream has its own modulation, norms, attention projections and MLP;
    queries/keys/values of both streams are concatenated for a single joint
    attention call, then the result is split back per stream.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        # Some model variants place image tokens before text tokens in the
        # joint sequence; this flag controls the concatenation/split order.
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
        img_qkv = self.img_attn.qkv(img_modulated)
        # (B, L, 3*H*D) -> 3 x (B, H, L, D)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        if self.flipped_img_txt:
            # run actual attention (image tokens first)
            attn = attention(torch.cat((img_q, txt_q), dim=2),
                             torch.cat((img_k, txt_k), dim=2),
                             torch.cat((img_v, txt_v), dim=2),
                             pe=pe, mask=attn_mask)

            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            # run actual attention (text tokens first)
            attn = attention(torch.cat((txt_q, img_q), dim=2),
                             torch.cat((txt_k, img_k), dim=2),
                             torch.cat((txt_v, img_v), dim=2),
                             pe=pe, mask=attn_mask)

            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img bloks
        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

        # calculate the txt bloks
        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

        if txt.dtype == torch.float16:
            # Clamp fp16 overflow to the finite range instead of propagating inf/nan.
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt
209
+
210
+
211
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
        """Single fused attention+MLP residual update on the joint sequence x."""
        mod, _ = self.modulation(vec)
        # One projection produces both the qkv and the MLP input.
        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        # (B, L, 3*H*D) -> 3 x (B, H, L, D)
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += apply_mod(output, mod.gate, None, modulation_dims)
        if x.dtype == torch.float16:
            # Clamp fp16 overflow to the finite range instead of propagating inf/nan.
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x
262
+
263
+
264
class LastLayer(nn.Module):
    # Final adaLN-modulated projection from hidden tokens to per-patch output
    # channels (patch_size**2 * out_channels per token).
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
        if vec.ndim == 2:
            # Add a sequence axis so the modulation broadcasts over tokens.
            vec = vec[:, None, :]

        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
        x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
        x = self.linear(x)
        return x
ComfyUI/comfy/ldm/flux/math.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+ from comfy.ldm.modules.attention import optimized_attention
6
+ import comfy.model_management
7
+
8
+
9
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    # Attention with optional rotary position embedding applied to q and k.
    # q/k/v are passed with skip_reshape=True, so presumably they are
    # (batch, heads, seq, head_dim) — confirm against optimized_attention.
    q_shape = q.shape
    k_shape = k.shape

    if pe is not None:
        # Inline rotary application: view the head dim as (pairs, 1, 2)
        # components and rotate each pair by the 2x2 matrices in pe.
        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
    return x
22
+
23
+
24
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    # Build rotary-embedding rotation matrices for positions `pos`.
    # Returns float32 of shape (b, n, dim/2, 2, 2) on pos's original device.
    assert dim % 2 == 0
    # MPS/XPU/DirectML backends can't run the float64 linspace below, so
    # compute on CPU there and move the result back at the end.
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
        device = torch.device("cpu")
    else:
        device = pos.device

    # Frequencies omega_d = theta^(-2d/dim), computed in float64 for accuracy.
    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    # Stack [cos, -sin, sin, cos] and fold the last axis into 2x2 rotation matrices.
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)
37
+
38
+
39
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    """Rotate query/key tensors by precomputed rotary 2x2 matrices.

    The last dimension of xq/xk is viewed as (pairs, 1, 2); each pair is
    multiplied by the corresponding rotation matrix in freqs_cis. Results are
    reshaped back and cast to the input dtypes.
    """
    def _rotate(x: Tensor) -> Tensor:
        pairs = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
45
+
ComfyUI/comfy/ldm/flux/model.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Original code can be found on: https://github.com/black-forest-labs/flux
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+ from einops import rearrange, repeat
8
+ import comfy.ldm.common_dit
9
+
10
+ from .layers import (
11
+ DoubleStreamBlock,
12
+ EmbedND,
13
+ LastLayer,
14
+ MLPEmbedder,
15
+ SingleStreamBlock,
16
+ timestep_embedding,
17
+ )
18
+
19
@dataclass
class FluxParams:
    # Latent channels before/after patchification (per-patch channels are
    # these times patch_size**2; see Flux.__init__).
    in_channels: int
    out_channels: int
    # Width of the pooled conditioning vector (y) input.
    vec_in_dim: int
    # Width of the text-encoder context tokens.
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    # Number of double-stream and single-stream transformer blocks.
    depth: int
    depth_single_blocks: int
    # Rotary embedding sizes per position axis; must sum to hidden_size // num_heads.
    axes_dim: list
    theta: int
    patch_size: int
    qkv_bias: bool
    # Whether the model embeds a guidance scale (flux-dev style).
    guidance_embed: bool
35
+
36
+
37
+ class Flux(nn.Module):
38
+ """
39
+ Transformer model for flow matching on sequences.
40
+ """
41
+
42
    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        """Build the Flux trunk.

        kwargs are the FluxParams fields; final_layer=False is used by
        controlnet subclasses that tap block outputs instead of decoding.
        """
        super().__init__()
        self.dtype = dtype
        params = FluxParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        # Channels per patchified token.
        self.in_channels = params.in_channels * params.patch_size * params.patch_size
        self.out_channels = params.out_channels * params.patch_size * params.patch_size
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        # Timestep / pooled-vector / optional guidance embedders all feed the
        # shared modulation vector.
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
90
+
91
+ def forward_orig(
92
+ self,
93
+ img: Tensor,
94
+ img_ids: Tensor,
95
+ txt: Tensor,
96
+ txt_ids: Tensor,
97
+ timesteps: Tensor,
98
+ y: Tensor,
99
+ guidance: Tensor = None,
100
+ control = None,
101
+ transformer_options={},
102
+ attn_mask: Tensor = None,
103
+ ) -> Tensor:
104
+
105
+ if y is None:
106
+ y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
107
+
108
+ patches_replace = transformer_options.get("patches_replace", {})
109
+ if img.ndim != 3 or txt.ndim != 3:
110
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
111
+
112
+ # running on sequences img
113
+ img = self.img_in(img)
114
+ vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
115
+ if self.params.guidance_embed:
116
+ if guidance is not None:
117
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
118
+
119
+ vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
120
+ txt = self.txt_in(txt)
121
+
122
+ if img_ids is not None:
123
+ ids = torch.cat((txt_ids, img_ids), dim=1)
124
+ pe = self.pe_embedder(ids)
125
+ else:
126
+ pe = None
127
+
128
+ blocks_replace = patches_replace.get("dit", {})
129
+ for i, block in enumerate(self.double_blocks):
130
+ if ("double_block", i) in blocks_replace:
131
+ def block_wrap(args):
132
+ out = {}
133
+ out["img"], out["txt"] = block(img=args["img"],
134
+ txt=args["txt"],
135
+ vec=args["vec"],
136
+ pe=args["pe"],
137
+ attn_mask=args.get("attn_mask"))
138
+ return out
139
+
140
+ out = blocks_replace[("double_block", i)]({"img": img,
141
+ "txt": txt,
142
+ "vec": vec,
143
+ "pe": pe,
144
+ "attn_mask": attn_mask},
145
+ {"original_block": block_wrap})
146
+ txt = out["txt"]
147
+ img = out["img"]
148
+ else:
149
+ img, txt = block(img=img,
150
+ txt=txt,
151
+ vec=vec,
152
+ pe=pe,
153
+ attn_mask=attn_mask)
154
+
155
+ if control is not None: # Controlnet
156
+ control_i = control.get("input")
157
+ if i < len(control_i):
158
+ add = control_i[i]
159
+ if add is not None:
160
+ img += add
161
+
162
+ if img.dtype == torch.float16:
163
+ img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
164
+
165
+ img = torch.cat((txt, img), 1)
166
+
167
+ for i, block in enumerate(self.single_blocks):
168
+ if ("single_block", i) in blocks_replace:
169
+ def block_wrap(args):
170
+ out = {}
171
+ out["img"] = block(args["img"],
172
+ vec=args["vec"],
173
+ pe=args["pe"],
174
+ attn_mask=args.get("attn_mask"))
175
+ return out
176
+
177
+ out = blocks_replace[("single_block", i)]({"img": img,
178
+ "vec": vec,
179
+ "pe": pe,
180
+ "attn_mask": attn_mask},
181
+ {"original_block": block_wrap})
182
+ img = out["img"]
183
+ else:
184
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
185
+
186
+ if control is not None: # Controlnet
187
+ control_o = control.get("output")
188
+ if i < len(control_o):
189
+ add = control_o[i]
190
+ if add is not None:
191
+ img[:, txt.shape[1] :, ...] += add
192
+
193
+ img = img[:, txt.shape[1] :, ...]
194
+
195
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
196
+ return img
197
+
198
+ def process_img(self, x, index=0, h_offset=0, w_offset=0):
199
+ bs, c, h, w = x.shape
200
+ patch_size = self.patch_size
201
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
202
+
203
+ img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
204
+ h_len = ((h + (patch_size // 2)) // patch_size)
205
+ w_len = ((w + (patch_size // 2)) // patch_size)
206
+
207
+ h_offset = ((h_offset + (patch_size // 2)) // patch_size)
208
+ w_offset = ((w_offset + (patch_size // 2)) // patch_size)
209
+
210
+ img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
211
+ img_ids[:, :, 0] = img_ids[:, :, 1] + index
212
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
213
+ img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
214
+ return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
215
+
216
+ def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
217
+ bs, c, h_orig, w_orig = x.shape
218
+ patch_size = self.patch_size
219
+
220
+ h_len = ((h_orig + (patch_size // 2)) // patch_size)
221
+ w_len = ((w_orig + (patch_size // 2)) // patch_size)
222
+ img, img_ids = self.process_img(x)
223
+ img_tokens = img.shape[1]
224
+ if ref_latents is not None:
225
+ h = 0
226
+ w = 0
227
+ for ref in ref_latents:
228
+ h_offset = 0
229
+ w_offset = 0
230
+ if ref.shape[-2] + h > ref.shape[-1] + w:
231
+ w_offset = w
232
+ else:
233
+ h_offset = h
234
+
235
+ kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
236
+ img = torch.cat([img, kontext], dim=1)
237
+ img_ids = torch.cat([img_ids, kontext_ids], dim=1)
238
+ h = max(h, ref.shape[-2] + h_offset)
239
+ w = max(w, ref.shape[-1] + w_offset)
240
+
241
+ txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
242
+ out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
243
+ out = out[:, :img_tokens]
244
+ return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
ComfyUI/comfy/ldm/flux/redux.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.ops
3
+
4
+ ops = comfy.ops.manual_cast
5
+
6
class ReduxImageEncoder(torch.nn.Module):
    """Projects SigLIP image embeddings into the text-conditioning space.

    A two-layer SiLU MLP: redux_dim -> 3 * txt_in_features -> txt_in_features.
    """

    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        self.redux_dim = redux_dim
        self.device = device
        self.dtype = dtype

        expanded = txt_in_features * 3
        self.redux_up = ops.Linear(redux_dim, expanded, dtype=dtype)
        self.redux_down = ops.Linear(expanded, txt_in_features, dtype=dtype)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        expanded = self.redux_up(sigclip_embeds)
        return self.redux_down(torch.nn.functional.silu(expanded))
ComfyUI/comfy/ldm/hidream/model.py ADDED
@@ -0,0 +1,802 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import einops
6
+ from einops import repeat
7
+
8
+ from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
9
+ import torch.nn.functional as F
10
+
11
+ from comfy.ldm.flux.math import apply_rope, rope
12
+ from comfy.ldm.flux.layers import LastLayer
13
+
14
+ from comfy.ldm.modules.attention import optimized_attention
15
+ import comfy.model_management
16
+ import comfy.ldm.common_dit
17
+
18
+
19
+ # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
20
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding (copied from BFL Flux layers).

    Each axis of the id tensor gets its own rope table of width axes_dim[i];
    the per-axis tables are concatenated along dim -3.
    """

    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        axis_embeds = []
        for axis, dim in enumerate(self.axes_dim[: ids.shape[-1]]):
            axis_embeds.append(rope(ids[..., axis], dim, self.theta))
        return torch.cat(axis_embeds, dim=-3).unsqueeze(2)
33
+
34
+
35
class PatchEmbed(nn.Module):
    """Linear patch embedding: maps flattened patch pixels to model width.

    Input is expected to be already patchified (…, in_channels * patch_size**2).
    """

    def __init__(
        self,
        patch_size=2,
        in_channels=4,
        out_channels=1024,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        in_features = in_channels * patch_size * patch_size
        self.proj = operations.Linear(in_features, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, latent):
        return self.proj(latent)
51
+
52
+
53
class PooledEmbed(nn.Module):
    """Embeds the pooled text vector via a timestep-style MLP embedder."""

    def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.pooled_embedder = TimestepEmbedding(
            in_channels=text_emb_dim,
            time_embed_dim=hidden_size,
            dtype=dtype, device=device, operations=operations,
        )

    def forward(self, pooled_embed):
        embedded = self.pooled_embedder(pooled_embed)
        return embedded
60
+
61
+
62
class TimestepEmbed(nn.Module):
    """Sinusoidal timestep projection followed by an MLP embedder."""

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size,
            time_embed_dim=hidden_size,
            dtype=dtype, device=device, operations=operations,
        )

    def forward(self, timesteps, wdtype):
        # Cast the sinusoidal features to the model weight dtype before the MLP.
        sinusoidal = self.time_proj(timesteps).to(dtype=wdtype)
        return self.timestep_embedder(sinusoidal)
72
+
73
+
74
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    """Run optimized attention on (B, T, H, D) tensors by flattening heads.

    Each tensor is reshaped to (B, T, H * D); the head count is recovered
    from query.shape[2].
    """
    num_heads = query.shape[2]
    q_flat = query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2])
    k_flat = key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2])
    v_flat = value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2])
    return optimized_attention(q_flat, k_flat, v_flat, num_heads)
76
+
77
+
78
class HiDreamAttnProcessor_flashattn:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __call__(
        self,
        attn,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # Joint image/text attention. In dual-stream mode (attn.single False)
        # image and text tokens get separate projections, attend together,
        # and are split back; in single-stream mode only image tokens exist.
        dtype = image_tokens.dtype
        batch_size = image_tokens.shape[0]

        # RMS-normalized q/k for the image stream; cast back because the norm
        # may upcast.
        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
        value_i = attn.to_v(image_tokens)

        inner_dim = key_i.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape to (batch, tokens, heads, head_dim).
        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
        if image_tokens_masks is not None:
            # Zero out keys of masked (padding) image tokens.
            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)

        if not attn.single:
            # Separate projections for the text stream.
            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
            value_t = attn.to_v_t(text_tokens)

            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

            num_image_tokens = query_i.shape[1]
            num_text_tokens = query_t.shape[1]
            # Image tokens first, then text tokens, in the joint sequence.
            query = torch.cat([query_i, query_t], dim=1)
            key = torch.cat([key_i, key_t], dim=1)
            value = torch.cat([value_i, value_t], dim=1)
        else:
            query = query_i
            key = key_i
            value = value_i

        if query.shape[-1] == rope.shape[-3] * 2:
            # Rope covers the full head dim.
            query, key = apply_rope(query, key, rope)
        else:
            # Rope covers only the first half of the head dim; the second half
            # passes through unrotated.
            query_1, query_2 = query.chunk(2, dim=-1)
            key_1, key_2 = key.chunk(2, dim=-1)
            query_1, key_1 = apply_rope(query_1, key_1, rope)
            query = torch.cat([query_1, query_2], dim=-1)
            key = torch.cat([key_1, key_2], dim=-1)

        hidden_states = attention(query, key, value)

        if not attn.single:
            # Split the joint output back and project each stream separately.
            hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
            hidden_states_i = attn.to_out(hidden_states_i)
            hidden_states_t = attn.to_out_t(hidden_states_t)
            return hidden_states_i, hidden_states_t
        else:
            hidden_states = attn.to_out(hidden_states)
            return hidden_states
145
+
146
class HiDreamAttention(nn.Module):
    """Attention module holding the q/k/v/out projections and q/k RMS norms.

    All computation is delegated to ``self.processor``; in dual-stream mode
    (``single=False``) a second set of projections (``*_t``) exists for the
    text tokens.
    """

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        scale_qk: bool = True,
        eps: float = 1e-5,
        processor = None,
        out_dim: int = None,
        single: bool = False,
        dtype=None, device=None, operations=None
    ):
        # super(Attention, self).__init__()
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.out_dim = out_dim if out_dim is not None else query_dim

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        # When out_dim is given, the head count is derived from it instead.
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.single = single

        linear_cls = operations.Linear
        self.linear_cls = linear_cls
        # NOTE(review): k/v take self.inner_dim as in_features (not query_dim);
        # fine here since the model uses query_dim == inner_dim — confirm if reused.
        self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
        self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
        self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
        self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
        self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
        self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)

        if not single:
            # Parallel projection set for the text stream (dual-stream blocks).
            self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
            self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
            self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
            self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
            self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
            self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)

        self.processor = processor

    def forward(
        self,
        norm_image_tokens: torch.FloatTensor,
        image_tokens_masks: torch.FloatTensor = None,
        norm_text_tokens: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.Tensor:
        # Delegate the actual attention computation to the processor.
        return self.processor(
            self,
            image_tokens = norm_image_tokens,
            image_tokens_masks = image_tokens_masks,
            text_tokens = norm_text_tokens,
            rope = rope,
        )
209
+
210
+
211
class FeedForwardSwiGLU(nn.Module):
    """SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)).

    The hidden width is 2/3 of the requested one (the gate doubles the
    parameter count), optionally scaled by ffn_dim_multiplier, then rounded
    up to a multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        width = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        # Round up to the nearest multiple of `multiple_of`.
        width = multiple_of * ((width + multiple_of - 1) // multiple_of)

        self.w1 = operations.Linear(dim, width, bias=False, dtype=dtype, device=device)
        self.w2 = operations.Linear(width, dim, bias=False, dtype=dtype, device=device)
        self.w3 = operations.Linear(dim, width, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        gate = torch.nn.functional.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
235
+
236
+
237
+ # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
238
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
    """Softmax top-k router over ``num_routed_experts`` experts.

    forward() returns (topk_idx, topk_weight, aux_loss); aux_loss is always
    None in this inference-only port.
    """

    def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
        super().__init__()
        self.top_k = num_activated_experts
        self.n_routed_experts = num_routed_experts

        self.scoring_func = 'softmax'
        self.alpha = aux_loss_alpha
        self.seq_aux = False

        # topk selection algorithm
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        # Gating matrix; left uninitialized here — weights come from the checkpoint.
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Intentionally a no-op: weights are loaded from the state dict.
        pass
        # import torch.nn.init as init
        # init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape

        ### compute gating score
        # Flatten to (bsz * seq_len, h); the gate weight is cast to the input
        # dtype/device before the linear.
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        ### norm gate to sum 1
        # Disabled by default (norm_topk_prob is False).
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        aux_loss = None
        return topk_idx, topk_weight, aux_loss
280
+
281
+
282
+ # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
283
+ class MOEFeedForwardSwiGLU(nn.Module):
284
+ def __init__(
285
+ self,
286
+ dim: int,
287
+ hidden_dim: int,
288
+ num_routed_experts: int,
289
+ num_activated_experts: int,
290
+ dtype=None, device=None, operations=None
291
+ ):
292
+ super().__init__()
293
+ self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
294
+ self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
295
+ self.gate = MoEGate(
296
+ embed_dim = dim,
297
+ num_routed_experts = num_routed_experts,
298
+ num_activated_experts = num_activated_experts,
299
+ dtype=dtype, device=device, operations=operations
300
+ )
301
+ self.num_activated_experts = num_activated_experts
302
+
303
+ def forward(self, x):
304
+ wtype = x.dtype
305
+ identity = x
306
+ orig_shape = x.shape
307
+ topk_idx, topk_weight, aux_loss = self.gate(x)
308
+ x = x.view(-1, x.shape[-1])
309
+ flat_topk_idx = topk_idx.view(-1)
310
+ if True: # self.training: # TODO: check which branch performs faster
311
+ x = x.repeat_interleave(self.num_activated_experts, dim=0)
312
+ y = torch.empty_like(x, dtype=wtype)
313
+ for i, expert in enumerate(self.experts):
314
+ y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
315
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
316
+ y = y.view(*orig_shape).to(dtype=wtype)
317
+ #y = AddAuxiliaryLoss.apply(y, aux_loss)
318
+ else:
319
+ y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
320
+ y = y + self.shared_experts(identity)
321
+ return y
322
+
323
+ @torch.no_grad()
324
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
325
+ expert_cache = torch.zeros_like(x)
326
+ idxs = flat_expert_indices.argsort()
327
+ tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
328
+ token_idxs = idxs // self.num_activated_experts
329
+ for i, end_idx in enumerate(tokens_per_expert):
330
+ start_idx = 0 if i == 0 else tokens_per_expert[i-1]
331
+ if start_idx == end_idx:
332
+ continue
333
+ expert = self.experts[i]
334
+ exp_token_idx = token_idxs[start_idx:end_idx]
335
+ expert_tokens = x[exp_token_idx]
336
+ expert_out = expert(expert_tokens)
337
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
338
+
339
+ # for fp16 and other dtype
340
+ expert_cache = expert_cache.to(expert_out.dtype)
341
+ expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
342
+ return expert_cache
343
+
344
+
345
class TextProjection(nn.Module):
    """Bias-free linear projection of caption embeddings to the model width."""

    def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.linear = operations.Linear(
            in_features=in_features,
            out_features=hidden_size,
            bias=False,
            dtype=dtype, device=device,
        )

    def forward(self, caption):
        return self.linear(caption)
353
+
354
+
355
class BlockType:
    # Enum-like constants selecting which transformer block class
    # HiDreamImageBlock instantiates.
    TransformerBlock = 1
    SingleTransformerBlock = 2
358
+
359
+
360
class HiDreamImageSingleTransformerBlock(nn.Module):
    """Single-stream block: adaLN-modulated self-attention and (MoE) FF over
    image tokens only. Text tokens are accepted but ignored."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        # Produces 6 modulation tensors (shift/scale/gate for attn and for FF).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
        )

        # 1. Attention
        self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = True,
            dtype=dtype, device=device, operations=operations
        )

        # 3. Feed-forward
        self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        # num_routed_experts == 0 disables MoE and uses a plain SwiGLU FF.
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
                dtype=dtype, device=device, operations=operations
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,

    ) -> torch.FloatTensor:
        wtype = image_tokens.dtype
        # [:, None] broadcasts the per-batch modulation over the token axis.
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        attn_output_i = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            rope = rope,
        )
        # Gated residual connection.
        image_tokens = gate_msa_i * attn_output_i + image_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
        image_tokens = ff_output_i + image_tokens
        return image_tokens
430
+
431
+
432
class HiDreamImageTransformerBlock(nn.Module):
    """Dual-stream block: adaLN-modulated joint attention over image + text
    tokens, then separate feed-forwards per stream (MoE optional on the image
    stream). Returns the updated (image_tokens, text_tokens) pair."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        # 12 modulation tensors: shift/scale/gate for attn and FF, per stream.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
        )
        # nn.init.zeros_(self.adaLN_modulation[1].weight)
        # nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = False,
            dtype=dtype, device=device, operations=operations
        )

        # 3. Feed-forward
        self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        # num_routed_experts == 0 disables MoE and uses a plain SwiGLU FF.
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
                dtype=dtype, device=device, operations=operations
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
        # Fix: pass dtype/device like every other LayerNorm in this file —
        # previously omitted, so this norm was created on the default
        # device/dtype.
        self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        wtype = image_tokens.dtype
        # [:, None] broadcasts the per-batch modulation over the token axis.
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
        shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t

        attn_output_i, attn_output_t = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            norm_text_tokens,
            rope = rope,
        )

        # Gated residual connections per stream.
        image_tokens = gate_msa_i * attn_output_i + image_tokens
        text_tokens = gate_msa_t * attn_output_t + text_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t

        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
        image_tokens = ff_output_i + image_tokens
        text_tokens = ff_output_t + text_tokens
        return image_tokens, text_tokens
518
+
519
+
520
class HiDreamImageBlock(nn.Module):
    """Thin wrapper selecting dual- or single-stream block by ``block_type``
    and forwarding all calls to it."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        block_type: BlockType = BlockType.TransformerBlock,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        if block_type == BlockType.SingleTransformerBlock:
            block_cls = HiDreamImageSingleTransformerBlock
        else:
            block_cls = HiDreamImageTransformerBlock
        self.block = block_cls(
            dim,
            num_attention_heads,
            attention_head_dim,
            num_routed_experts,
            num_activated_experts,
            dtype=dtype, device=device, operations=operations
        )

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        return self.block(
            image_tokens,
            image_tokens_masks,
            text_tokens,
            adaln_input,
            rope,
        )
560
+
561
+
562
class HiDreamImageTransformer2DModel(nn.Module):
    """HiDream image diffusion transformer.

    Patchifies the latent image, runs `num_layers` double-stream blocks
    (separate image/text token streams) followed by `num_single_layers`
    single-stream blocks (image and text tokens concatenated), and
    unpatchifies the result. Conditioning comes from a timestep + pooled
    text embedding (adaLN input), per-layer Llama hidden states selected by
    `llama_layers`, and a T5 embedding projected by the last entry of
    `caption_projection`.
    """

    def __init__(
        self,
        patch_size: Optional[int] = None,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 16,
        num_single_layers: int = 32,
        attention_head_dim: int = 128,
        num_attention_heads: int = 20,
        caption_channels: Optional[List[int]] = None,
        text_emb_dim: int = 2048,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        axes_dims_rope: Tuple[int, int] = (32, 32),
        max_resolution: Tuple[int, int] = (128, 128),
        llama_layers: Optional[List[int]] = None,
        image_model=None,
        dtype=None, device=None, operations=None
    ):
        # Plain (non-module) attributes may be set before Module.__init__.
        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.num_layers = num_layers
        self.num_single_layers = num_single_layers

        self.gradient_checkpointing = False

        super().__init__()
        self.dtype = dtype
        self.out_channels = out_channels or in_channels
        self.inner_dim = self.num_attention_heads * self.attention_head_dim
        self.llama_layers = llama_layers

        # Timestep / pooled-text embedders feeding the adaLN conditioning.
        self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
        self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
        self.x_embedder = PatchEmbed(
            patch_size = patch_size,
            in_channels = in_channels,
            out_channels = self.inner_dim,
            dtype=dtype, device=device, operations=operations
        )
        # Rotary position embedding over (extra, H, W) axes.
        self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)

        self.double_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.num_attention_heads,
                    attention_head_dim = self.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.TransformerBlock,
                    dtype=dtype, device=device, operations=operations
                )
                for i in range(self.num_layers)
            ]
        )

        self.single_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.num_attention_heads,
                    attention_head_dim = self.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.SingleTransformerBlock,
                    dtype=dtype, device=device, operations=operations
                )
                for i in range(self.num_single_layers)
            ]
        )

        self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

        # One projection per block fed from caption_channels[1] (the per-layer
        # Llama states), plus a final projection for caption_channels[0]
        # (applied to the T5 states in forward via caption_projection[-1]).
        caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
        caption_projection = []
        for caption_channel in caption_channels:
            caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
        self.caption_projection = nn.ModuleList(caption_projection)
        # Maximum number of image tokens at max_resolution with this patch size.
        self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

    def expand_timesteps(self, timesteps, batch_size, device):
        """Coerce `timesteps` to a 1-D tensor and broadcast it to `batch_size`."""
        if not torch.is_tensor(timesteps):
            # MPS lacks float64/int64 support for this path.
            is_mps = device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(batch_size)
        return timesteps

    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
        """Fold per-sample token sequences back into image tensors (B, C, H, W)."""
        x_arr = []
        for i, img_size in enumerate(img_sizes):
            pH, pW = img_size
            # Only the first pH*pW tokens of each sample are real image tokens.
            x_arr.append(
                einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
                    p1=self.patch_size, p2=self.patch_size)
            )
        x = torch.cat(x_arr, dim=0)
        return x

    def patchify(self, x, max_seq, img_sizes=None):
        """Turn images (or pre-tokenized input) into a token sequence.

        Returns (tokens, mask-or-None, per-sample (pH, pW) sizes). The mask is
        produced only when explicit `img_sizes` are given (ragged batch case).
        """
        pz2 = self.patch_size * self.patch_size
        if isinstance(x, torch.Tensor):
            B = x.shape[0]
            device = x.device
            dtype = x.dtype
        else:
            # List-of-tensors input: take device/dtype from the first element.
            B = len(x)
            device = x[0].device
            dtype = x[0].dtype
        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)

        if img_sizes is not None:
            # Mask marks the valid tokens of each (possibly differently sized) sample.
            for i, img_size in enumerate(img_sizes):
                x_masks[i, 0:img_size[0] * img_size[1]] = 1
            x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
        elif isinstance(x, torch.Tensor):
            # Uniform batch: patchify directly, no mask needed.
            pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
            x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
            img_sizes = [[pH, pW]] * B
            x_masks = None
        else:
            raise NotImplementedError
        return x, x_masks, img_sizes

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        encoder_hidden_states_llama3=None,
        image_cond=None,
        control = None,
        transformer_options = {},
    ) -> torch.Tensor:
        """Denoising forward pass.

        x: latent image (B, C, H, W); t: timestep; y: pooled text embedding;
        context: T5 hidden states; encoder_hidden_states_llama3: stacked Llama
        layer outputs (layer axis first after movedim). Returns the negated
        model output cropped back to the input spatial size.
        """
        bs, c, h, w = x.shape
        if image_cond is not None:
            # Image conditioning is concatenated along width.
            x = torch.cat([x, image_cond], dim=-1)
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        timesteps = t
        pooled_embeds = y
        T5_encoder_hidden_states = context

        img_sizes = None

        # spatial forward
        batch_size = hidden_states.shape[0]
        hidden_states_type = hidden_states.dtype

        # 0. time
        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
        timesteps = self.t_embedder(timesteps, hidden_states_type)
        p_embedder = self.p_embedder(pooled_embeds)
        adaln_input = timesteps + p_embedder

        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
        if image_tokens_masks is None:
            # Build 2D positional ids for RoPE: channel 1 = row, channel 2 = column.
            pH, pW = img_sizes[0]
            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
        hidden_states = self.x_embedder(hidden_states)

        # T5_encoder_hidden_states = encoder_hidden_states[0]
        # Select the Llama layers used for conditioning (layer axis first).
        encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]

        if self.caption_projection is not None:
            # Project every selected Llama state, then the T5 state (last projection).
            new_encoder_hidden_states = []
            for i, enc_hidden_state in enumerate(encoder_hidden_states):
                enc_hidden_state = self.caption_projection[i](enc_hidden_state)
                enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                new_encoder_hidden_states.append(enc_hidden_state)
            encoder_hidden_states = new_encoder_hidden_states
            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
            encoder_hidden_states.append(T5_encoder_hidden_states)

        # Text tokens use all-zero positional ids.
        txt_ids = torch.zeros(
            batch_size,
            encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
            3,
            device=img_ids.device, dtype=img_ids.dtype
        )
        ids = torch.cat((img_ids, txt_ids), dim=1)
        rope = self.pe_embedder(ids)

        # 2. Blocks
        block_id = 0
        # Shared text prefix for every double-stream block: T5 state + last Llama state.
        initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
        initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
        for bid, block in enumerate(self.double_stream_blocks):
            # Each block also sees its own per-layer Llama state appended.
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            hidden_states, initial_encoder_hidden_states = block(
                image_tokens = hidden_states,
                image_tokens_masks = image_tokens_masks,
                text_tokens = cur_encoder_hidden_states,
                adaln_input = adaln_input,
                rope = rope,
            )
            # Trim back to the shared prefix length for the next block.
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        # Single-stream phase: image and text tokens share one sequence.
        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if image_tokens_masks is not None:
            # Text tokens are always attendable; extend the mask with ones.
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
            )
            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)

        for bid, block in enumerate(self.single_stream_blocks):
            # Append this layer's Llama state, run the block, then drop the appendix.
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            hidden_states = block(
                image_tokens=hidden_states,
                image_tokens_masks=image_tokens_masks,
                text_tokens=None,
                adaln_input=adaln_input,
                rope=rope,
            )
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        # Keep only the image tokens, project out, and fold back to image space.
        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
        output = self.final_layer(hidden_states, adaln_input)
        output = self.unpatchify(output, img_sizes)
        # Output sign is flipped; crop padding added by pad_to_patch_size.
        return -output[:, :, :h, :w]
ComfyUI/comfy/ldm/hunyuan3d/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from comfy.ldm.flux.layers import (
4
+ DoubleStreamBlock,
5
+ LastLayer,
6
+ MLPEmbedder,
7
+ SingleStreamBlock,
8
+ timestep_embedding,
9
+ )
10
+
11
+
12
class Hunyuan3Dv2(nn.Module):
    """Hunyuan3D v2 shape-latent transformer built from Flux double/single stream blocks.

    Operates on a sequence of shape latents with no positional embedding
    (`pe` is passed as None to every block), conditioned on a timestep
    embedding, an optional guidance embedding, and a projected context.
    """

    def __init__(
        self,
        in_channels=64,
        context_in_dim=1536,
        hidden_size=1024,
        mlp_ratio=4.0,
        num_heads=16,
        depth=16,
        depth_single_blocks=32,
        qkv_bias=True,
        guidance_embed=False,
        image_model=None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.dtype = dtype

        if hidden_size % num_heads != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
            )

        self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
        self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
        )
        self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(depth)
            ]
        )
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(depth_single_blocks)
            ]
        )
        self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)

    def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
        """Predict the model output for shape latents `x`.

        `transformer_options["patches_replace"]["dit"]` may supply wrappers keyed
        by ("double_block", i) / ("single_block", i) that replace individual
        blocks (used e.g. for controlnet-style patching).
        """
        # Latents arrive channel-last on the trailing dim; move channels to the end.
        x = x.movedim(-1, -2)
        # Time runs in the opposite direction relative to the sampler convention.
        timestep = 1.0 - timestep
        txt = context
        img = self.latent_in(x)

        vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
        if self.guidance_in is not None:
            if guidance is not None:
                vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))

        txt = self.cond_in(txt)
        # No rotary embedding and no attention mask for this model.
        pe = None
        attn_mask = None

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if ("double_block", i) in blocks_replace:
                # Hand the original block to the replacement as "original_block".
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"],
                                                   txt=args["txt"],
                                                   vec=args["vec"],
                                                   pe=args["pe"],
                                                   attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("double_block", i)]({"img": img,
                                                           "txt": txt,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 attn_mask=attn_mask)

        # Single-stream phase: text tokens are prepended to the image tokens.
        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"],
                                       vec=args["vec"],
                                       pe=args["pe"],
                                       attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("single_block", i)]({"img": img,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

        # Strip the text tokens, project out, restore layout, and flip the sign.
        img = img[:, txt.shape[1]:, ...]
        img = self.final_layer(img, vec)
        return img.movedim(-2, -1) * (-1.0)
ComfyUI/comfy/ldm/hunyuan3d/vae.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
2
+ # Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ from typing import Union, Tuple, List, Callable, Optional
10
+
11
+ import numpy as np
12
+ from einops import repeat, rearrange
13
+ from tqdm import tqdm
14
+ import logging
15
+
16
+ import comfy.ops
17
+ ops = comfy.ops.disable_weight_init
18
+
19
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Sample a dense (res+1)^3 lattice of 3D points spanning a bounding box.

    Args:
        bbox_min: minimum corner (3,) of the box.
        bbox_max: maximum corner (3,) of the box.
        octree_resolution: number of cells per axis; the grid has one more
            sample than cells along each axis.
        indexing: meshgrid indexing convention ("ij" or "xy").

    Returns:
        xyz: float32 array of shape (res+1, res+1, res+1, 3) holding the
            grid coordinates, with the coordinate triplet in the last dim.
        grid_size: [res+1, res+1, res+1].
        length: per-axis extent, bbox_max - bbox_min.
    """
    length = bbox_max - bbox_min
    n_pts = int(octree_resolution) + 1

    # One linspace per axis, then a meshgrid stacked into the trailing dim.
    axes = [
        np.linspace(bbox_min[axis], bbox_max[axis], n_pts, dtype=np.float32)
        for axis in range(3)
    ]
    mesh = np.meshgrid(*axes, indexing=indexing)
    xyz = np.stack(mesh, axis=-1)
    grid_size = [n_pts, n_pts, n_pts]

    return xyz, grid_size, length
36
+
37
+
38
class VanillaVolumeDecoder:
    """Evaluates a geometry decoder over a dense 3D grid of query points, in chunks."""

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Decode occupancy/SDF logits on a regular grid.

        Args:
            latents: shape latents of shape (batch, ..., ...); only the batch
                size, device and dtype are read here — the rest is consumed
                by `geo_decoder`.
            geo_decoder: callable taking (queries, latents) and returning
                per-query logits.
            bounds: either a single float f (expanded to the symmetric box
                [-f, f]^3) or a 6-vector (xmin, ymin, zmin, xmax, ymax, zmax).
            num_chunks: number of query points evaluated per decoder call.
            octree_resolution: grid cells per axis (grid has res+1 samples).
            enable_pbar: show a tqdm progress bar over chunks.

        Returns:
            float32 logits of shape (batch, res+1, res+1, res+1).
        """
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        # Flatten the grid into a (num_points, 3) query list.
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            # Same chunk of queries is replicated across the batch.
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        # Reassemble chunk outputs into the full (batch, X, Y, Z) grid.
        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()

        return grid_logits
80
+
81
+
82
class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(x[..., i]),
            sin(f_1*x[..., i]),
            sin(f_2*x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(x[..., i]),
            cos(f_1*x[..., i]),
            cos(f_2*x[..., i]),
            ...
            cos(f_N * x[..., i]),
            x[..., i]     # only present if include_input is True.
        ], here f_i is the frequency.

    Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
    If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
    Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
        input_dim (int): the input dimension, default is 3;
        include_input (bool): include the input tensor or not, default is True.

    Attributes:
        frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];

        out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
            otherwise, it is input_dim * num_freqs * 2.

    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:

        """The initialization"""

        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            # Scale all frequencies by pi.
            frequencies *= torch.pi

        # Non-persistent buffer: moves with the module but is not saved in state_dict.
        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the embedding width produced for an `input_dim`-channel input."""
        # The raw input contributes one extra channel per dim when included
        # (or when there are no frequencies at all, since forward returns x).
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """

        if self.num_freqs > 0:
            # Outer product of each input channel with every frequency, flattened.
            embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            # No frequencies: identity embedding.
            return x
177
+
178
+
179
class CrossAttentionProcessor:
    """Minimal attention callable: plain scaled-dot-product attention over (q, k, v)."""

    def __call__(self, attn, q, k, v):
        # `attn` (the owning module) is accepted for interface compatibility but unused.
        return F.scaled_dot_product_attention(q, k, v)
183
+
184
+
185
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Randomly zero whole samples of `x` with probability `drop_prob`.

        This is the stochastic-depth regularizer: at inference time, or when
        drop_prob is 0, the input passes through untouched. During training a
        Bernoulli mask is drawn per sample and broadcast over all remaining
        dimensions; when `scale_by_keep` is set, survivors are divided by the
        keep probability so the expected activation is unchanged.
        """
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One mask entry per sample, broadcast across every non-batch dim.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            mask.div_(keep_prob)
        return x * mask

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
215
+
216
+
217
class MLP(nn.Module):
    """Two-layer GELU feed-forward: width -> width*expand_ratio -> output_width (or width)."""

    def __init__(
        self, *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.c_fc = ops.Linear(width, width * expand_ratio)
        self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
        self.gelu = nn.GELU()
        # Stochastic depth on the whole branch; identity when the rate is 0.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        # fc -> GELU -> proj, optionally dropped as a whole path during training.
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
234
+
235
+
236
class QKVMultiheadCrossAttention(nn.Module):
    """Cross-attention core operating on a pre-projected query and a packed kv tensor."""

    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        # Optional per-head LayerNorm applied to q and k before attention.
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        """Attend q over kv.

        q: (bs, n_ctx, q_width); kv: (bs, n_data, kv_width) where the last dim
        packs k and v per head (each of size attn_ch). Returns (bs, n_ctx, -1).
        """
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        # Per-head channel count for each of k and v (kv packs both halves).
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        # Move heads in front of the sequence dim for SDPA.
        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = self.attn_processor(self, q, k, v)
        # Merge heads back into the channel dim.
        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
        return out
266
+
267
+
268
class MultiheadCrossAttention(nn.Module):
    """Cross-attention with q/kv input projections, an output projection, and an optional kv cache."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = ops.Linear(width, width, bias=qkv_bias)
        # Single projection producing k and v packed along the last dim.
        self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.kv_cache = kv_cache
        # Stateful cache of the projected kv (filled on first forward when kv_cache is on).
        self.data = None

    def forward(self, x, data):
        """Cross-attend `x` (queries) over `data` (keys/values).

        NOTE(review): with kv_cache enabled the projected kv from the FIRST call
        is reused for all later calls — the log message documents the one-mesh
        assumption; different `data` afterwards is silently ignored.
        """
        x = self.c_q(x)
        if self.kv_cache:
            if self.data is None:
                self.data = self.c_kv(data)
                logging.info('Save kv cache,this should be called only once for one mesh')
            data = self.data
        else:
            data = self.c_kv(data)
        x = self.attention(x, data)
        x = self.c_proj(x)
        return x
308
+
309
+
310
class ResidualCrossAttentionBlock(nn.Module):
    """Pre-norm residual block: cross-attention into `data`, then an MLP."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False
    ):
        super().__init__()

        # Keys/values default to the same width as the queries.
        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        # ln_1 normalizes the queries, ln_2 the kv data, ln_3 the MLP input.
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x
344
+
345
+
346
class QKVMultiheadAttention(nn.Module):
    """Self-attention core operating on a packed qkv tensor."""

    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        # Optional per-head LayerNorm applied to q and k before attention.
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, qkv):
        """Self-attend a packed qkv tensor of shape (bs, n_ctx, 3 * inner_width)."""
        bs, n_ctx, width = qkv.shape
        # Per-head channel count for each of q, k, v (qkv packs all three).
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # Heads in front of the sequence dim for SDPA, then merged back.
        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        return out
372
+
373
+
374
class MultiheadAttention(nn.Module):
    """Self-attention with fused qkv projection, output projection, and optional drop-path."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        # Single projection producing q, k and v packed along the last dim.
        self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        # Stochastic depth on the projected output; identity when the rate is 0.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.drop_path(self.c_proj(x))
        return x
403
+
404
+
405
class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual transformer block: self-attention then MLP."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
433
+
434
+
435
class Transformer(nn.Module):
    """A plain stack of `layers` ResidualAttentionBlocks applied in sequence."""

    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width=width,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    drop_path_rate=drop_path_rate
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x
468
+
469
+
470
class CrossAttentionDecoder(nn.Module):
    """Decodes per-query occupancy logits by cross-attending query embeddings into shape latents.

    Query points are Fourier-embedded, projected to `width`, cross-attended
    over the latents, optionally post-normalized, and projected to
    `out_channels` logits.
    """

    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            # Latents arrive `downsample_ratio` times wider and are projected down.
            self.latents_proj = ops.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:  # idiomatic form of the original `== False` check
            # qk normalization is disabled together with the post-norm.
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        if self.enable_ln_post:
            self.ln_post = ops.LayerNorm(width)
        self.output_proj = ops.Linear(width, out_channels)
        self.label_type = label_type
        # Running total of query tokens processed (incremented per forward).
        self.count = 0

    def forward(self, queries=None, query_embeddings=None, latents=None):
        """Return logits for `queries` (3D points) or pre-computed `query_embeddings`.

        Exactly one of `queries` / `query_embeddings` is expected; when
        `query_embeddings` is None it is derived from `queries`.
        """
        if query_embeddings is None:
            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ
521
+
522
+
523
class ShapeVAE(nn.Module):
    """Decode-only wrapper around the Hunyuan3D shape VAE.

    `decode` projects the KL latents, runs the decoder transformer, and
    evaluates the geometry decoder on a dense grid; `encode` is not
    implemented and returns None.
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        # Projects KL latents (embed_dim) up to the transformer width.
        self.post_kl = ops.Linear(embed_dim, width)

        self.transformer = Transformer(
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        # Geo decoder runs at a (possibly) reduced width/head count.
        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.volume_decoder = VanillaVolumeDecoder()
        self.scale_factor = scale_factor

    def decode(self, latents, **kwargs):
        """Decode latents into a dense grid of occupancy logits.

        kwargs: bounds (default 1.01), num_chunks (8000),
        octree_resolution (256), enable_pbar (True).
        """
        # Latents arrive channel-first; post_kl expects channels last.
        latents = self.post_kl(latents.movedim(-2, -1))
        latents = self.transformer(latents)

        bounds = kwargs.get("bounds", 1.01)
        num_chunks = kwargs.get("num_chunks", 8000)
        octree_resolution = kwargs.get("octree_resolution", 256)
        enable_pbar = kwargs.get("enable_pbar", True)

        grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
        return grid_logits.movedim(-2, -1)

    def encode(self, x):
        # Encoding is not supported by this wrapper.
        return None
ComfyUI/comfy/ldm/hunyuan_video/model.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Based on Flux code because of weird hunyuan video code license.
2
+
3
+ import torch
4
+ import comfy.ldm.flux.layers
5
+ import comfy.ldm.modules.diffusionmodules.mmdit
6
+ from comfy.ldm.modules.attention import optimized_attention
7
+
8
+
9
+ from dataclasses import dataclass
10
+ from einops import repeat
11
+
12
+ from torch import Tensor, nn
13
+
14
+ from comfy.ldm.flux.layers import (
15
+ DoubleStreamBlock,
16
+ EmbedND,
17
+ LastLayer,
18
+ MLPEmbedder,
19
+ SingleStreamBlock,
20
+ timestep_embedding
21
+ )
22
+
23
+ import comfy.ldm.common_dit
24
+
25
+
26
@dataclass
class HunyuanVideoParams:
    """Configuration for the HunyuanVideo transformer (Flux-style parameter set)."""
    in_channels: int          # latent channels fed to the 3D patch embedder
    out_channels: int         # latent channels produced by the final layer
    vec_in_dim: int           # size of the pooled conditioning vector `y`
    context_in_dim: int       # per-token text embedding dim (input to TokenRefiner)
    hidden_size: int          # transformer width; must be divisible by num_heads
    mlp_ratio: float          # MLP expansion ratio inside the blocks
    num_heads: int            # attention heads
    depth: int                # number of double-stream blocks
    depth_single_blocks: int  # number of single-stream blocks
    axes_dim: list            # per-axis RoPE dims; must sum to hidden_size // num_heads
    theta: int                # RoPE base frequency
    patch_size: list          # (t, h, w) patch sizes
    qkv_bias: bool            # whether QKV projections carry a bias
    guidance_embed: bool      # whether a distilled-guidance embedder is present
42
+
43
+
44
class SelfAttentionRef(nn.Module):
    """Parameter container for self-attention: fused QKV plus output projection.

    Holds only the linear layers; the attention computation itself is done by
    the caller (see TokenRefinerBlock.forward), which is why there is no
    forward() here.
    """
    def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        # Fused query/key/value projection: dim -> 3 * dim.
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
49
+
50
+
51
class TokenRefinerBlock(nn.Module):
    """Pre-norm transformer block used to refine text tokens.

    Each residual branch (self-attention, MLP) is gated by a per-sample vector
    (mod1 / mod2) derived from the conditioning `c` — adaLN-style gating only,
    with no shift/scale of the normalized activations themselves.
    """
    def __init__(
        self,
        hidden_size,
        heads,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.heads = heads
        mlp_hidden_dim = hidden_size * 4  # fixed 4x MLP expansion

        # Produces the two residual gates (mod1 for attention, mod2 for MLP).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
        self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)

        self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)

        self.mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x, c, mask):
        """Refine tokens.

        Args:
            x: (B, L, hidden_size) token embeddings.
            c: (B, hidden_size) conditioning vector.
            mask: optional attention bias forwarded to optimized_attention.
        """
        mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn.qkv(norm_x)
        # (B, L, 3*H*D) -> q, k, v each shaped (B, heads, L, head_dim).
        q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
        attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)

        # Gated residual updates.
        x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
        x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
        return x
91
+
92
+
93
class IndividualTokenRefiner(nn.Module):
    """Stack of TokenRefinerBlocks sharing one conditioning vector."""
    def __init__(
        self,
        hidden_size,
        heads,
        num_blocks,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                TokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads=heads,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )
                for _ in range(num_blocks)
            ]
        )

    def forward(self, x, c, mask):
        """Run all refiner blocks over x.

        mask: optional (B, L) additive bias per token; it is expanded to a
        (B, 1, L, L) attention bias and symmetrized before use so the bias
        applies to a pair regardless of which side is query or key.
        """
        m = None
        if mask is not None:
            m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
            m = m + m.transpose(2, 3)

        for block in self.blocks:
            x = block(x, c, m)
        return x
126
+
127
+
128
+
129
class TokenRefiner(nn.Module):
    """Embeds raw text features and refines them under timestep conditioning."""
    def __init__(
        self,
        text_dim,
        hidden_size,
        heads,
        num_blocks,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()

        self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
        self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
        self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
        self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)

    def forward(
        self,
        x,
        timesteps,
        mask,
    ):
        """x: (B, L, text_dim) text features -> (B, L, hidden_size) refined tokens."""
        t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
        # Unmasked mean-pool of the text features as the pooled conditioning.
        # The masked variant below was deliberately disabled upstream:
        # m = mask.float().unsqueeze(-1)
        # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
        c = x.sum(dim=1) / x.shape[1]

        c = t + self.c_embedder(c.to(x.dtype))
        x = self.input_embedder(x)
        x = self.individual_token_refiner(x, c, mask)
        return x
162
+
163
class HunyuanVideo(nn.Module):
    """
    Transformer model for flow matching on sequences.

    Flux-style architecture: patchified video latents (the "img" stream) and
    refined text tokens (the "txt" stream) go through double-stream blocks,
    are then concatenated and processed by single-stream blocks, and finally
    projected back to latent space and un-patchified.
    """

    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        params = HunyuanVideoParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        # Rotary position embedder over the (t, h, w) id axes.
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)

        # 3D (conv3d) patch embedding of the latent video.
        self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
        # Distilled-guidance embedder; identity when the model has none.
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
        )

        # Text tokens are refined by 2 timestep-conditioned blocks before use.
        self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    flipped_img_txt=True,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        txt_mask: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        guiding_frame_index=None,
        ref_latent=None,
        control=None,
        transformer_options={},
    ) -> Tensor:
        """Core forward pass on sequences.

        Args:
            img: (B, C, T, H, W) latent video; patchified internally.
            img_ids / txt_ids: positional id grids used for RoPE.
            txt: raw text embeddings, refined by self.txt_in.
            txt_mask: text padding mask; non-float masks are converted to an
                additive attention bias below.
            y: pooled conditioning vector (first vec_in_dim channels are used).
            guidance: optional distilled-guidance scale.
            guiding_frame_index: if set, first-frame tokens get their own
                modulation vector (image-to-video style conditioning).
            ref_latent: optional reference latent prepended to the sequence.
            control: optional ControlNet residuals ({"input": [...], "output": [...]}).
        """
        patches_replace = transformer_options.get("patches_replace", {})

        initial_shape = list(img.shape)
        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))

        if ref_latent is not None:
            # Prepend reference-latent tokens; give them t-index -1 and shift
            # their w-ids past the real frame so RoPE keeps them distinct.
            ref_latent_ids = self.img_ids(ref_latent)
            ref_latent = self.img_in(ref_latent)
            img = torch.cat([ref_latent, img], dim=-2)
            ref_latent_ids[..., 0] = -1
            ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
            img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)

        if guiding_frame_index is not None:
            # First-frame tokens are modulated by the guiding-frame timestep,
            # the rest by the regular timestep (two modulation rows in `vec`).
            token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
            vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
            vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
            frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
            modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
            modulation_dims_txt = [(0, None, 1)]
        else:
            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
            modulation_dims = None
            modulation_dims_txt = None

        if self.params.guidance_embed:
            if guidance is not None:
                vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

        if txt_mask is not None and not torch.is_floating_point(txt_mask):
            # Convert a 0/1 padding mask into an additive bias:
            # kept tokens -> 0, padded tokens -> -finfo.max.
            txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max

        txt = self.txt_in(txt, timesteps, txt_mask)

        ids = torch.cat((img_ids, txt_ids), dim=1)
        pe = self.pe_embedder(ids)

        img_len = img.shape[1]
        if txt_mask is not None:
            # Attention bias over the concatenated (img, txt) sequence;
            # image positions are always attendable (zero bias).
            attn_mask_len = img_len + txt.shape[1]
            attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
            attn_mask[:, 0, img_len:] = txt_mask
        else:
            attn_mask = None

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if ("double_block", i) in blocks_replace:
                # Patch hook: wrap the original block so external code can
                # intercept/replace this block's computation.
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)

            if control is not None: # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

        # Merge streams for the single-stream blocks.
        img = torch.cat((img, txt), 1)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
                    return out

                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)

            if control is not None: # Controlnet
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        # Only the image portion of the merged sequence
                        # receives the controlnet residual.
                        img[:, : img_len] += add

        # Drop text tokens, then any prepended reference tokens.
        img = img[:, : img_len]
        if ref_latent is not None:
            img = img[:, ref_latent.shape[1]:]

        img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)

        # Un-patchify back to (B, C_out, T, H, W).
        shape = initial_shape[-3:]
        for i in range(len(shape)):
            shape[i] = shape[i] // self.patch_size[i]
        img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
        img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
        img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
        return img

    def img_ids(self, x):
        """Build the flattened (t, h, w) RoPE id grid for latent x: (B, T*H*W, 3)."""
        bs, c, t, h, w = x.shape
        patch_size = self.patch_size
        # Rounded-up patch counts per axis.
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
        return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
        """Entry point: builds positional id grids, then delegates to forward_orig."""
        bs, c, t, h, w = x.shape
        img_ids = self.img_ids(x)
        # Text tokens share position (0, 0, 0); their ordering is carried by
        # the sequence dimension itself.
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
        return out
ComfyUI/comfy/ldm/hydit/attn_layers.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Tuple, Union, Optional
4
+ from comfy.ldm.modules.attention import optimized_attention
5
+
6
+
7
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape RoPE frequency tensor(s) so they broadcast against `x`.

    All dimensions of `x` except the sequence dim and the last (feature) dim
    are collapsed to 1 in the returned view.  With ``head_first`` the sequence
    dim is assumed to be ``x.shape[-2]``; otherwise it is ``x.shape[1]``.

    Args:
        freqs_cis: either a complex frequency tensor, or a (cos, sin) tuple of
            real tensors.  Must be shaped (seq_len, feature_dim).
        x: target tensor to broadcast against.
        head_first: head dimension first (except batch dim) or not.

    Returns:
        The reshaped tensor, or a (cos, sin) tuple of reshaped tensors when a
        tuple was passed in.

    Raises:
        AssertionError: if the frequency shape does not match `x`.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    # Decide which axes of x the frequencies align with, and build the
    # broadcast shape (1 everywhere else) once for both input forms.
    if head_first:
        expected = (x.shape[-2], x.shape[-1])
        keep = {ndim - 2, ndim - 1}
    else:
        expected = (x.shape[1], x.shape[-1])
        keep = {1, ndim - 1}
    shape = [d if i in keep else 1 for i, d in enumerate(x.shape)]

    if isinstance(freqs_cis, tuple):
        # (cos, sin) pair in real space.
        assert freqs_cis[0].shape == expected, f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)

    # Single tensor of complex values.
    assert freqs_cis.shape == expected, f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
    return freqs_cis.view(*shape)
47
+
48
+
49
def rotate_half(x):
    """Rotate adjacent pairs (a, b) -> (-b, a) along the last dim (RoPE helper)."""
    pairs = x.reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
    real = pairs[..., 0]
    imag = pairs[..., 1]
    return torch.stack((-imag, real), dim=-1).flatten(3)
52
+
53
+
54
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: Optional[torch.Tensor],
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to query (and optionally key) tensors.

    Two equivalent code paths are supported:
      * tuple ``freqs_cis = (cos, sin)`` — real-valued rotation via
        ``x * cos + rotate_half(x) * sin``;
      * single complex ``freqs_cis`` — elementwise complex multiplication.

    Args:
        xq: Query tensor to rotate. [B, S, H, D]
        xk: Optional key tensor to rotate. [B, S, H, D]
        freqs_cis: Precomputed frequencies (complex tensor or (cos, sin) pair).
        head_first: head dimension first (except batch dim) or not.

    Returns:
        Tuple of rotated (xq_out, xk_out); xk_out is None when xk is None.
    """
    if isinstance(freqs_cis, tuple):
        # Real-valued rotation path.
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        rotated_q = xq * cos + rotate_half(xq) * sin
        rotated_k = None if xk is None else xk * cos + rotate_half(xk) * sin
        return rotated_q, rotated_k

    # Complex path: view pairs of the last dim as complex numbers.
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
    freqs = reshape_for_broadcast(freqs_cis, q_complex, head_first).to(xq.device)  # [S, D//2] -> [1, S, 1, D//2]
    rotated_q = torch.view_as_real(q_complex * freqs).flatten(3).type_as(xq)
    if xk is None:
        return rotated_q, None
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
    rotated_k = torch.view_as_real(k_complex * freqs).flatten(3).type_as(xk)
    return rotated_q, rotated_k
93
+
94
+
95
+
96
class CrossAttention(nn.Module):
    """
    Cross-attention (x attends to y) with optional QK normalization and
    optional RoPE applied to the queries only.
    """
    def __init__(self,
                 qdim,
                 kdim,
                 num_heads,
                 qkv_bias=True,
                 qk_norm=False,
                 attn_drop=0.0,
                 proj_drop=0.0,
                 attn_precision=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.attn_precision = attn_precision
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        # Fused key/value projection from the context dim to 2 * qdim.
        self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // 2), RoPE for image

        Returns
        -------
        A 1-tuple containing the attended output (batch, seqlen1, hidden_dim).
        """
        b, s1, c = x.shape       # [b, s1, D]
        _, s2, c = y.shape       # [b, s2, 1024]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)    # [b, s1, h, d]
        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)    # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2) # [b, s, h, d]
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Apply RoPE if needed — queries only; keys come from the text context.
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq

        q = q.transpose(-2, -3).contiguous()       # q ->  B, L1, H, C - B, H, L1, C
        k = k.transpose(-2, -3).contiguous()       # k ->  B, L2, H, C - B, H, C, L2
        v = v.transpose(-2, -3).contiguous()

        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)

        out = self.out_proj(context)  # context.reshape - B, L1, -1
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple
172
+
173
+
174
class Attention(nn.Module):
    """
    Self-attention with optional QK normalization and optional RoPE.

    We rename some layer names to align with flash attention (qkv -> Wqkv).
    """
    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn_precision = attn_precision
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        # qkv --> Wqkv
        self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis_img=None):
        """x: (B, N, C) tokens; returns a 1-tuple with the attended (B, N, C) output."""
        B, N, C = x.shape
        qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # [3, b, h, s, d]
        q, k, v = qkv.unbind(0)     # [b, h, s, d]
        q = self.q_norm(q)          # [b, h, s, d]
        k = self.k_norm(k)          # [b, h, s, d]

        # Apply RoPE if needed — both queries and keys, head-first layout.
        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
            assert qq.shape == q.shape and kk.shape == k.shape, \
                f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
        x = self.out_proj(x)
        x = self.proj_drop(x)

        out_tuple = (x,)

        return out_tuple
ComfyUI/comfy/ldm/hydit/controlnet.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ from comfy.ldm.modules.diffusionmodules.mmdit import (
7
+ TimestepEmbedder,
8
+ PatchEmbed,
9
+ )
10
+ from .poolers import AttentionPool
11
+
12
+ import comfy.latent_formats
13
+ from .models import HunYuanDiTBlock, calc_rope
14
+
15
+
16
+
17
class HunYuanControlNet(nn.Module):
    """
    HunYuanDiT ControlNet: a 19-block copy of the HunYuanDiT trunk that
    consumes a conditioning `hint` and emits one residual per block for
    the main model ("input" residuals).

    Parameters
    ----------
    input_size: tuple
        The size of the input image.
    patch_size: int
        The size of the patch.
    in_channels: int
        The number of input channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    depth: int
        The number of transformer blocks.
    num_heads: int
        The number of attention heads.
    mlp_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    log_fn: callable
        The logging function.
    """

    def __init__(
        self,
        input_size: tuple = 128,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1408,
        depth: int = 40,
        num_heads: int = 16,
        mlp_ratio: float = 4.3637,
        text_states_dim=1024,
        text_states_dim_t5=2048,
        text_len=77,
        text_len_t5=256,
        qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
        size_cond=False,
        use_style_cond=False,
        learn_sigma=True,
        norm="layer",
        log_fn: callable = print,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.log_fn = log_fn
        self.depth = depth
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.text_states_dim = text_states_dim
        self.text_states_dim_t5 = text_states_dim_t5
        self.text_len = text_len
        self.text_len_t5 = text_len_t5
        self.size_cond = size_cond
        self.use_style_cond = use_style_cond
        self.norm = norm
        self.dtype = dtype
        self.latent_format = comfy.latent_formats.SDXL

        # Projects T5 token embeddings down to the CLIP text-state dim.
        self.mlp_t5 = nn.Sequential(
            nn.Linear(
                self.text_states_dim_t5,
                self.text_states_dim_t5 * 4,
                bias=True,
                dtype=dtype,
                device=device,
            ),
            nn.SiLU(),
            nn.Linear(
                self.text_states_dim_t5 * 4,
                self.text_states_dim,
                bias=True,
                dtype=dtype,
                device=device,
            ),
        )
        # learnable replace: learned embeddings substituted for padded tokens.
        self.text_embedding_padding = nn.Parameter(
            torch.randn(
                self.text_len + self.text_len_t5,
                self.text_states_dim,
                dtype=dtype,
                device=device,
            )
        )

        # Attention pooling
        pooler_out_dim = 1024
        self.pooler = AttentionPool(
            self.text_len_t5,
            self.text_states_dim_t5,
            num_heads=8,
            output_dim=pooler_out_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Dimension of the extra input vectors
        self.extra_in_dim = pooler_out_dim

        if self.size_cond:
            # Image size and crop size conditions
            self.extra_in_dim += 6 * 256

        if self.use_style_cond:
            # Here we use a default learned embedder layer for future extension.
            self.style_embedder = nn.Embedding(
                1, hidden_size, dtype=dtype, device=device
            )
            self.extra_in_dim += hidden_size

        # Text embedding for `add`
        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            hidden_size,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations
        )
        self.extra_embedder = nn.Sequential(
            operations.Linear(
                self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
            ),
            nn.SiLU(),
            operations.Linear(
                hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
            ),
        )

        # HUnYuanDiT Blocks — fixed 19 blocks (the controlled prefix of the trunk).
        self.blocks = nn.ModuleList(
            [
                HunYuanDiTBlock(
                    hidden_size=hidden_size,
                    c_emb_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    text_states_dim=self.text_states_dim,
                    qk_norm=qk_norm,
                    norm_type=self.norm,
                    skip=False,
                    attn_precision=attn_precision,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(19)
            ]
        )

        # Input zero linear for the first block
        self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)


        # Output zero linear for the every block
        self.after_proj_list = nn.ModuleList(
            [

                operations.Linear(
                    self.hidden_size, self.hidden_size, dtype=dtype, device=device
                )
                for _ in range(len(self.blocks))
            ]
        )

    def forward(
        self,
        x,
        hint,
        timesteps,
        context,#encoder_hidden_states=None,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        return_dict=False,
        **kwarg,
    ):
        """
        Forward pass of the encoder.

        Parameters
        ----------
        x: torch.Tensor
            (B, D, H, W)
        hint: torch.Tensor
            ControlNet conditioning image latent; broadcast over the batch
            when given with batch size 1.
        timesteps: torch.Tensor
            (B)
        context: torch.Tensor
            CLIP text embedding, (B, L_clip, D)
        text_embedding_mask: torch.Tensor
            CLIP text embedding mask, (B, L_clip)
        encoder_hidden_states_t5: torch.Tensor
            T5 text embedding, (B, L_t5, D)
        text_embedding_mask_t5: torch.Tensor
            T5 text embedding mask, (B, L_t5)
        image_meta_size: torch.Tensor
            (B, 6)
        style: torch.Tensor
            (B)
        return_dict: bool
            Whether to return a dictionary.

        Returns
        -------
        {"output": [residual_0, ..., residual_18]} — one residual per block.
        """
        condition = hint
        if condition.shape[0] == 1:
            condition = torch.repeat_interleave(condition, x.shape[0], dim=0)

        text_states = context                     # 2,77,1024
        text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
        text_states_mask = text_embedding_mask.bool()        # 2,77
        text_states_t5_mask = text_embedding_mask_t5.bool()  # 2,256
        b_t5, l_t5, c_t5 = text_states_t5.shape
        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

        # Replace masked-out (padded) tokens with the learned padding rows.
        # NOTE(review): this writes into `text_states` (i.e. the caller's
        # `context` tensor) in place — confirm callers don't reuse it.
        text_states[:, -self.text_len :] = torch.where(
            text_states_mask[:, -self.text_len :].unsqueeze(2),
            text_states[:, -self.text_len :],
            padding[: self.text_len],
        )
        text_states_t5[:, -self.text_len_t5 :] = torch.where(
            text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
            text_states_t5[:, -self.text_len_t5 :],
            padding[self.text_len :],
        )

        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024

        # _, _, oh, ow = x.shape
        # th, tw = oh // self.patch_size, ow // self.patch_size

        # Get image RoPE embedding according to `reso`lution.
        freqs_cis_img = calc_rope(
            x, self.patch_size, self.hidden_size // self.num_heads
        )  # (cos_cis_img, sin_cis_img)

        # ========================= Build time and image embedding =========================
        t = self.t_embedder(timesteps, dtype=self.dtype)
        x = self.x_embedder(x)

        # ========================= Concatenate all extra vectors =========================
        # Build text tokens with pooling
        extra_vec = self.pooler(encoder_hidden_states_t5)

        # Build image meta size tokens if applicable
        # (disabled upstream; kept for reference)
        # if image_meta_size is not None:
        #     image_meta_size = timestep_embedding(image_meta_size.view(-1), 256)   # [B * 6, 256]
        #     if image_meta_size.dtype != self.dtype:
        #         image_meta_size = image_meta_size.half()
        #     image_meta_size = image_meta_size.view(-1, 6 * 256)
        #     extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]

        # Build style tokens
        # NOTE(review): self.style_embedder only exists when use_style_cond is
        # True — passing `style` otherwise would raise AttributeError; verify
        # callers gate on use_style_cond.
        if style is not None:
            style_embedding = self.style_embedder(style)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        # Concatenate all extra vectors
        c = t + self.extra_embedder(extra_vec)  # [B, D]

        # ========================= Deal with Condition =========================
        condition = self.x_embedder(condition)

        # ========================= Forward pass through HunYuanDiT blocks =========================
        controls = []
        x = x + self.before_proj(condition)  # add condition
        for layer, block in enumerate(self.blocks):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(self.after_proj_list[layer](x))  # zero linear for output

        return {"output": controls}
ComfyUI/comfy/ldm/hydit/models.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ import comfy.ops
6
+ from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
7
+ from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
8
+ from torch.utils import checkpoint
9
+
10
+ from .attn_layers import Attention, CrossAttention
11
+ from .poolers import AttentionPool
12
+ from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
13
+
14
def calc_rope(x, patch_size, head_size):
    """Build the 2D rotary position embedding (cos, sin) tables for latent `x`.

    The token grid is derived from the latent's spatial size with round-half-up
    division by `patch_size`; the grid is then resized/cropped relative to the
    base grid of a 512px image (after the 8x VAE downscale) before sampling.
    """
    half_patch = patch_size // 2
    token_h = (x.shape[2] + half_patch) // patch_size
    token_w = (x.shape[3] + half_patch) // patch_size
    # Base grid corresponds to a 512-pixel image: 512 / 8 (VAE) / patch_size.
    base_size = 512 // 8 // patch_size
    start, stop = get_fill_resize_and_crop((token_h, token_w), base_size)
    cos_cis_img, sin_cis_img = get_2d_rotary_pos_embed(head_size, start, stop, (token_h, token_w))
    return cos_cis_img.to(x), sin_cis_img.to(x)
24
+
25
+
26
def modulate(x, shift, scale):
    """Apply adaLN-style affine modulation, broadcasting over the token axis."""
    one_plus_scale = scale.unsqueeze(1) + 1
    return x * one_plus_scale + shift.unsqueeze(1)
28
+
29
+
30
class HunYuanDiTBlock(nn.Module):
    """
    A HunYuanDiT transformer block with `add` conditioning.

    Layout per block: optional long-skip merge -> self-attention (with an
    additive shift derived from the conditioning vector) -> cross-attention
    over text states -> feed-forward. All three sub-layers are residual.
    """
    def __init__(self,
                 hidden_size,
                 c_emb_size,
                 num_heads,
                 mlp_ratio=4.0,
                 text_states_dim=1024,
                 qk_norm=False,
                 norm_type="layer",
                 skip=False,
                 attn_precision=None,
                 dtype=None,
                 device=None,
                 operations=None,
                 ):
        super().__init__()
        use_ele_affine = True

        # Norm flavor is configurable; both come from the (castable) ops wrapper.
        if norm_type == "layer":
            norm_layer = operations.LayerNorm
        elif norm_type == "rms":
            norm_layer = operations.RMSNorm
        else:
            raise ValueError(f"Unknown norm_type: {norm_type}")

        # ========================= Self-Attention =========================
        self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)

        # ========================= FFN =========================
        self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, dtype=dtype, device=device, operations=operations)

        # ========================= Add =========================
        # Conditioning enters as a simple additive shift (like SDXL), not
        # a full scale/shift/gate adaLN.
        self.default_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_emb_size, hidden_size, bias=True, dtype=dtype, device=device)
        )

        # ========================= Cross-Attention =========================
        self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
                                    qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)

        # ========================= Skip Connection =========================
        # Only the second half of the network's blocks receive a long skip
        # (U-Net-style) from the mirrored early block.
        if skip:
            self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
            self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, dtype=dtype, device=device)
        else:
            self.skip_linear = None

        self.gradient_checkpointing = False

    def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        # Long Skip Connection: concat with the mirrored early-block output,
        # then norm + project back down to hidden_size.
        if self.skip_linear is not None:
            cat = torch.cat([x, skip], dim=-1)
            if cat.dtype != x.dtype:
                cat = cat.to(x.dtype)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        # Self-Attention, with the conditioning shift added pre-attention.
        shift_msa = self.default_modulation(c).unsqueeze(dim=1)
        attn_inputs = (
            self.norm1(x) + shift_msa, freq_cis_img,
        )
        x = x + self.attn1(*attn_inputs)[0]

        # Cross-Attention over the (CLIP+T5) text states.
        cross_inputs = (
            self.norm3(x), text_states, freq_cis_img
        )
        x = x + self.attn2(*cross_inputs)[0]

        # FFN Layer
        mlp_inputs = self.norm2(x)
        x = x + self.mlp(mlp_inputs)

        return x

    def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        # Gradient checkpointing trades compute for memory during training only.
        if self.gradient_checkpointing and self.training:
            return checkpoint.checkpoint(self._forward, x, c, text_states, freq_cis_img, skip)
        return self._forward(x, c, text_states, freq_cis_img, skip)
121
+
122
+
123
class FinalLayer(nn.Module):
    """
    The final layer of HunYuanDiT: adaLN-modulated norm followed by a linear
    projection from token features to flattened patch pixels.
    """
    def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        patch_dim = patch_size * patch_size * out_channels
        # Non-affine norm: the scale/shift are produced from `c` at runtime.
        self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(final_hidden_size, patch_dim, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
141
+
142
+
143
class HunYuanDiT(nn.Module):
    """
    HunYuanDiT: Diffusion model with a Transformer backbone.

    The backbone is a U-Net-shaped stack of DiT blocks: outputs of the first
    half are stored and fed as long skips into the mirrored blocks of the
    second half. Text conditioning combines a CLIP branch and a T5 branch;
    a pooled T5 vector (optionally with size/style embeddings) is added to
    the timestep embedding to form the per-sample conditioning vector.

    Parameters
    ----------
    input_size: tuple
        The size of the input image.
    patch_size: int
        The size of the patch.
    in_channels: int
        The number of input channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    depth: int
        The number of transformer blocks.
    num_heads: int
        The number of attention heads.
    mlp_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    text_states_dim: int
        Channel dim of the CLIP text states (also the post-MLP T5 dim).
    text_states_dim_t5: int
        Channel dim of the raw T5 text states.
    text_len, text_len_t5: int
        Token counts for the CLIP and T5 branches.
    qk_norm: bool
        Normalize Q/K in attention (see http://arxiv.org/abs/2302.05442).
    size_cond, use_style_cond: bool
        Enable image-meta-size / style conditioning vectors.
    learn_sigma: bool
        If True the model predicts twice the channels (mean + variance).
    log_fn: callable
        The logging function.
    """
    #@register_to_config
    def __init__(self,
                 input_size: tuple = 32,
                 patch_size: int = 2,
                 in_channels: int = 4,
                 hidden_size: int = 1152,
                 depth: int = 28,
                 num_heads: int = 16,
                 mlp_ratio: float = 4.0,
                 text_states_dim = 1024,
                 text_states_dim_t5 = 2048,
                 text_len = 77,
                 text_len_t5 = 256,
                 qk_norm = True,# See http://arxiv.org/abs/2302.05442 for details.
                 size_cond = False,
                 use_style_cond = False,
                 learn_sigma = True,
                 norm = "layer",
                 log_fn: callable = print,
                 attn_precision=None,
                 dtype=None,
                 device=None,
                 operations=None,
                 **kwargs,
                 ):
        super().__init__()
        self.log_fn = log_fn
        self.depth = depth
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # learn_sigma doubles the output channels (predicted mean + variance).
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.text_states_dim = text_states_dim
        self.text_states_dim_t5 = text_states_dim_t5
        self.text_len = text_len
        self.text_len_t5 = text_len_t5
        self.size_cond = size_cond
        self.use_style_cond = use_style_cond
        self.norm = norm
        self.dtype = dtype

        # Projects raw T5 states down to the shared text_states_dim so CLIP
        # and T5 tokens can be concatenated into one context sequence.
        self.mlp_t5 = nn.Sequential(
            operations.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True, dtype=dtype, device=device),
        )
        # Learnable replacement embeddings for masked-out (padding) tokens.
        self.text_embedding_padding = nn.Parameter(
            torch.empty(self.text_len + self.text_len_t5, self.text_states_dim, dtype=dtype, device=device))

        # Attention pooling of the T5 states into one vector per sample.
        pooler_out_dim = 1024
        self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim, dtype=dtype, device=device, operations=operations)

        # Dimension of the extra input vectors
        self.extra_in_dim = pooler_out_dim

        if self.size_cond:
            # Image size and crop size conditions: 6 values x 256-dim embedding.
            self.extra_in_dim += 6 * 256

        if self.use_style_cond:
            # Here we use a default learned embedder layer for future extension.
            self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
            self.extra_in_dim += hidden_size

        # Patch + timestep embedders, and the MLP fusing all extra vectors.
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, dtype=dtype, device=device, operations=operations)
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device, operations=operations)
        self.extra_embedder = nn.Sequential(
            operations.Linear(self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
        )

        # HunYuanDiT blocks; the second half (layer > depth // 2) takes skips.
        self.blocks = nn.ModuleList([
            HunYuanDiTBlock(hidden_size=hidden_size,
                            c_emb_size=hidden_size,
                            num_heads=num_heads,
                            mlp_ratio=mlp_ratio,
                            text_states_dim=self.text_states_dim,
                            qk_norm=qk_norm,
                            norm_type=self.norm,
                            skip=layer > depth // 2,
                            attn_precision=attn_precision,
                            dtype=dtype,
                            device=device,
                            operations=operations,
                            )
            for layer in range(depth)
        ])

        self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
        self.unpatchify_channels = self.out_channels



    def forward(self,
                x,
                t,
                context,#encoder_hidden_states=None,
                text_embedding_mask=None,
                encoder_hidden_states_t5=None,
                text_embedding_mask_t5=None,
                image_meta_size=None,
                style=None,
                return_dict=False,
                control=None,
                transformer_options={},
                ):
        """
        Forward pass of the encoder.

        Parameters
        ----------
        x: torch.Tensor
            (B, D, H, W) latent image.
        t: torch.Tensor
            (B) timesteps.
        context: torch.Tensor
            CLIP text embedding, (B, L_clip, D).
        text_embedding_mask: torch.Tensor
            CLIP text embedding mask, (B, L_clip).
        encoder_hidden_states_t5: torch.Tensor
            T5 text embedding, (B, L_t5, D).
        text_embedding_mask_t5: torch.Tensor
            T5 text embedding mask, (B, L_t5).
        image_meta_size: torch.Tensor
            (B, 6) — only used when self.size_cond is enabled.
        style: torch.Tensor
            (B) — only used when self.use_style_cond is enabled.
        return_dict: bool
            Whether to return a dictionary.
        control: dict or None
            Optional ControlNet outputs, added onto the long skips.
        """
        patches_replace = transformer_options.get("patches_replace", {})
        encoder_hidden_states = context
        text_states = encoder_hidden_states # 2,77,1024
        text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
        text_states_mask = text_embedding_mask.bool() # 2,77
        text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
        b_t5, l_t5, c_t5 = text_states_t5.shape
        # Project T5 states to text_states_dim (flatten tokens for the MLP).
        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

        # Replace masked-out tokens with the learned padding embeddings.
        # NOTE(review): indexing uses the trailing text_len/text_len_t5 tokens,
        # presumably to tolerate longer-than-expected contexts — verify with callers.
        text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
        text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])

        # One fused context sequence: CLIP tokens followed by T5 tokens.
        text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,205,1024

        _, _, oh, ow = x.shape
        # Token-grid size (round-half-up), used by unpatchify at the end.
        th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size


        # Get image RoPE embedding according to `reso`lution.
        freqs_cis_img = calc_rope(x, self.patch_size, self.hidden_size // self.num_heads) #(cos_cis_img, sin_cis_img)

        # ========================= Build time and image embedding =========================
        t = self.t_embedder(t, dtype=x.dtype)
        x = self.x_embedder(x)

        # ========================= Concatenate all extra vectors =========================
        # Build text tokens with pooling
        extra_vec = self.pooler(encoder_hidden_states_t5)

        # Build image meta size tokens if applicable
        if self.size_cond:
            image_meta_size = timestep_embedding(image_meta_size.view(-1), 256).to(x.dtype) # [B * 6, 256]
            image_meta_size = image_meta_size.view(-1, 6 * 256)
            extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]

        # Build style tokens
        if self.use_style_cond:
            if style is None:
                # Default to style index 0 when no style is supplied.
                style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
            style_embedding = self.style_embedder(style, out_dtype=x.dtype)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        # Conditioning vector: timestep embedding + fused extra vectors.
        c = t + self.extra_embedder(extra_vec) # [B, D]

        blocks_replace = patches_replace.get("dit", {})

        controls = None
        if control:
            controls = control.get("output", None)
        # ========================= Forward pass through HunYuanDiT blocks =========================
        # First half pushes its outputs onto `skips`; second half pops them
        # (optionally adding ControlNet residuals) as long skip inputs.
        skips = []
        for layer, block in enumerate(self.blocks):
            if layer > self.depth // 2:
                if controls is not None:
                    skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                else:
                    skip = skips.pop()
            else:
                skip = None

            # Allow external patches (e.g. samplers) to wrap individual blocks.
            if ("double_block", layer) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
                    return out

                out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)


            if layer < (self.depth // 2 - 1):
                skips.append(x)
        if controls is not None and len(controls) != 0:
            raise ValueError("The number of controls is not equal to the number of skip connections.")

        # ========================= Final layer =========================
        x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
        x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)

        if return_dict:
            return {'x': x}
        # Drop the predicted-variance channels and any padding rows/cols.
        if self.learn_sigma:
            return x[:,:self.out_channels // 2,:oh,:ow]
        return x[:,:,:oh,:ow]

    def unpatchify(self, x, h, w):
        """
        Fold a token sequence back into an image.

        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.unpatchify_channels
        p = self.x_embedder.patch_size[0]
        assert h * w == x.shape[1]

        # (N, h, w, p, p, C) -> (N, C, h, p, w, p) -> (N, C, h*p, w*p)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
ComfyUI/comfy/ldm/hydit/poolers.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from comfy.ldm.modules.attention import optimized_attention
4
+ import comfy.ops
5
+
6
class AttentionPool(nn.Module):
    # Pools a token sequence into one vector via single-query attention:
    # the query is the (position-embedded) mean token, keys/values are all
    # tokens plus that mean token.
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
        super().__init__()
        # spacial_dim + 1 slots: one extra for the prepended mean token.
        self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
        self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x):
        # Pool (N, L, C) token features into a single (N, output_dim) vector.
        # Truncate to the length the positional table supports.
        x = x[:,:self.positional_embedding.shape[0] - 1]
        x = x.permute(1, 0, 2)  # NLC -> LNC
        # Prepend the mean token; it becomes the sole attention query below.
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x)  # (L+1)NC

        q = self.q_proj(x[:1])  # only the mean token queries
        k = self.k_proj(x)
        v = self.v_proj(x)

        batch_size = q.shape[1]
        head_dim = self.embed_dim // self.num_heads
        # (T, N, C) -> (N, heads, T, head_dim) layout for the fused kernel.
        q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)

        attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)

        attn_output = self.c_proj(attn_output)
        # Drop the singleton query axis: (1, N, D) -> (N, D).
        return attn_output.squeeze(0)
ComfyUI/comfy/ldm/hydit/posemb_layers.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import Union
4
+
5
+
6
+ def _to_tuple(x):
7
+ if isinstance(x, int):
8
+ return x, x
9
+ else:
10
+ return x
11
+
12
+
13
def get_fill_resize_and_crop(src, tgt):
    """Center a `src`-shaped region, aspect-preserved, inside a `tgt` frame.

    Both arguments may be an int (square) or an (h, w) pair. Returns
    ((top, left), (bottom, right)) in the target coordinate system.
    """
    th, tw = (tgt, tgt) if isinstance(tgt, int) else tgt
    h, w = (src, src) if isinstance(src, int) else src

    target_ratio = th / tw
    source_ratio = h / w

    # Fit `src` inside `tgt` without distorting its aspect ratio.
    if source_ratio > target_ratio:
        # Source is relatively taller: match heights, scale width down.
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        # Source is relatively wider (or equal): match widths, scale height down.
        resize_width = tw
        resize_height = int(round(tw / w * h))

    # Center the resized region within the target frame.
    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
32
+
33
+
34
def get_meshgrid(start, *args):
    """Return a stacked float32 coordinate grid of shape [2, H, W].

    Call patterns:
      get_meshgrid(num)               -> grid over [0, num) x [0, num)
      get_meshgrid(start, stop)       -> unit-step grid from start to stop
      get_meshgrid(start, stop, num)  -> num samples from start to stop
    """
    pair = lambda v: (v, v) if isinstance(v, int) else v

    if len(args) == 0:
        # `start` is actually the grid size.
        count = pair(start)
        begin, end = (0, 0), count
    elif len(args) == 1:
        # (start, stop) with implicit step 1.
        begin = pair(start)
        end = pair(args[0])
        count = (end[0] - begin[0], end[1] - begin[1])
    elif len(args) == 2:
        # (start, stop, num).
        begin = pair(start)
        end = pair(args[0])
        count = pair(args[1])
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    axis_h = np.linspace(begin[0], end[0], count[0], endpoint=False, dtype=np.float32)
    axis_w = np.linspace(begin[1], end[1], count[1], endpoint=False, dtype=np.float32)
    # meshgrid with w first: index 0 holds x (width) coords, index 1 holds y.
    return np.stack(np.meshgrid(axis_w, axis_h), axis=0)
58
+
59
+ #################################################################################
60
+ # Sine/Cosine Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
63
+
64
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
    """Sin/cos positional table for a 2D grid.

    Returns a [H*W, embed_dim] array, optionally prefixed with
    `extra_tokens` zero rows when `cls_token` is set.
    """
    grid = get_meshgrid(start, *args)  # [2, H, W]
    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        cls_rows = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([cls_rows, pos_embed], axis=0)
    return pos_embed
81
+
82
+
83
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode each grid axis with half the channels, then concatenate to (H*W, D)."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
92
+
93
+
94
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Classic transformer sin/cos embedding.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions, any shape (flattened to M entries)
    returns: (M, embed_dim) with sin channels first, then cos.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    freqs = 1. / 10000**(np.arange(half, dtype=np.float64) / (embed_dim / 2.))  # (D/2,)
    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
113
+
114
+
115
+ #################################################################################
116
+ # Rotary Positional Embedding Functions #
117
+ #################################################################################
118
+ # https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
119
+
120
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
    """2D RoPE table for image tokens.

    Parameters
    ----------
    embed_dim: int
        embedding dimension size
    start: int or tuple of int
        Grid specification, forwarded to `get_meshgrid` (size, or
        start/stop, or start/stop/num — see that function).
    use_real: bool
        If True, return (cos, sin) real tensors; otherwise complex numbers.

    Returns
    -------
    [HW, D/2] complex tensor, or a (cos, sin) pair of [HW, D] tensors.
    """
    grid = get_meshgrid(start, *args)  # [2, H, W]
    # Sampling matrix matching the target resolution.
    grid = grid.reshape([2, 1, *grid.shape[1:]])
    return get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
143
+
144
+
145
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    """Rotary embedding over a 2D grid: half of the channels encode each axis."""
    assert embed_dim % 4 == 0
    half = embed_dim // 2
    emb_h = get_1d_rotary_pos_embed(half, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
    emb_w = get_1d_rotary_pos_embed(half, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)

    if not use_real:
        return torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2) complex
    cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
    sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
    return cos, sin
159
+
160
+
161
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
    """Precompute RoPE frequencies for 1D positions.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (np.ndarray, int): Position indices, [S]; an int means arange(pos).
        theta (float, optional): Frequency scaling base. Defaults to 10000.0.
        use_real (bool, optional): If True, return (cos, sin) real tensors of
            shape [S, D]; otherwise a complex64 tensor of shape [S, D/2].

    Returns:
        torch.Tensor or (torch.Tensor, torch.Tensor): see `use_real`.
    """
    if isinstance(pos, int):
        pos = np.arange(pos)
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta ** exponents)  # [D/2]
    positions = torch.from_numpy(pos).to(freqs.device)  # [S]
    angles = torch.outer(positions, freqs).float()  # [S, D/2]
    if not use_real:
        # Unit-modulus complex exponentials e^{i*angle}.
        return torch.polar(torch.ones_like(angles), angles)  # complex64, [S, D/2]
    # Duplicate each channel so cos/sin line up with interleaved rotation.
    freqs_cos = angles.cos().repeat_interleave(2, dim=1)  # [S, D]
    freqs_sin = angles.sin().repeat_interleave(2, dim=1)  # [S, D]
    return freqs_cos, freqs_sin
192
+
193
+
194
+
195
def calc_sizes(rope_img, patch_size, th, tw):
    """Translate a `rope_img` mode string into get_2d_rotary_pos_embed args."""
    if rope_img == 'extend':
        # Expansion mode: RoPE table directly spans the token grid.
        return [(th, tw)]
    if rope_img.startswith('base'):
        # 'base<pixels>': interpolate other sizes from a fixed base grid.
        base_size = int(rope_img[4:]) // 8 // patch_size
        start, stop = get_fill_resize_and_crop((th, tw), base_size)
        return [start, stop, (th, tw)]
    raise ValueError(f"Unknown rope_img: {rope_img}")
207
+
208
+
209
def init_image_posemb(rope_img,
                      resolutions,
                      patch_size,
                      hidden_size,
                      num_heads,
                      log_fn,
                      rope_real=True,
                      ):
    # Precompute a RoPE table per supported resolution, keyed by str(reso).
    # `resolutions` items expose .height/.width in pixels; the token grid is
    # pixels / 8 (VAE downscale) / patch_size.
    freqs_cis_img = {}
    for reso in resolutions:
        th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
        sub_args = calc_sizes(rope_img, patch_size, th, tw)
        freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
        log_fn(f"    Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
               f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
    return freqs_cis_img
ComfyUI/comfy/ldm/lightricks/model.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import comfy.ldm.modules.attention
4
+ import comfy.ldm.common_dit
5
+ from einops import rearrange
6
+ import math
7
+ from typing import Dict, Optional, Tuple
8
+
9
+ from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
10
+
11
+
12
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Sinusoidal timestep embeddings, matching the DDPM formulation.

    Args:
        timesteps: 1-D tensor of N (possibly fractional) indices, one per
            batch element.
        embedding_dim: width of the output embedding.
        flip_sin_to_cos: emit `[cos, sin]` order instead of `[sin, cos]`.
        downscale_freq_shift: controls the delta between frequencies across
            dimensions.
        scale: multiplier applied to the raw angles.
        max_period: controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor of shape [N, embedding_dim].
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    freqs = torch.exp(exponent)
    angles = scale * (timesteps[:, None].float() * freqs[None, :])

    # sin block then cos block along the channel axis.
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    if flip_sin_to_cos:
        # Swap the two halves so cos comes first.
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # Odd target width: right-pad with a zero column.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
64
+
65
+
66
class TimestepEmbedding(nn.Module):
    """Two-layer MLP projecting a sinusoidal timestep embedding, with an
    optional conditioning projection added to the input.

    Note: `act_fn` is accepted for interface compatibility but the activation
    is hard-coded to SiLU, as in the original implementation.

    Fix: the original left `self.post_act` unassigned when `post_act_fn` was
    supplied (the activation lookup was commented out), so every `forward()`
    call crashed with AttributeError. We now fail fast at construction with a
    clear error instead; passing `post_act_fn=None` behaves exactly as before.
    """
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
        dtype=None, device=None, operations=None,
    ):
        super().__init__()

        self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)

        if cond_proj_dim is not None:
            # Projects an extra conditioning signal into the input space
            # (no bias, so a zero condition is a no-op).
            self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)

        if post_act_fn is None:
            self.post_act = None
        else:
            # Previously this branch was commented out, leaving `post_act`
            # undefined and deferring the failure to the first forward().
            raise NotImplementedError(f"post_act_fn={post_act_fn!r} is not supported")

    def forward(self, sample, condition=None):
        """Embed `sample` (N, in_channels) -> (N, out_dim or time_embed_dim),
        optionally adding the projected `condition` to the input first."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
113
+
114
+
115
class Timesteps(nn.Module):
    """Module wrapper around :func:`get_timestep_embedding`.

    Stores the sinusoidal-embedding configuration at construction time so the
    projection can be used as a regular submodule.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        """Return the sinusoidal embedding of a 1-D ``timesteps`` tensor."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
132
+
133
+
134
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.outdim = size_emb_dim
        # 256-channel sinusoidal features followed by an MLP up to embedding_dim.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        # resolution / aspect_ratio / batch_size are accepted for interface
        # compatibility but not used here (no additional conditions).
        proj = self.time_proj(timestep)
        return self.timestep_embedder(proj.to(dtype=hidden_dtype))  # (N, D)
153
+
154
+
155
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
        )

        self.silu = nn.SiLU()
        # Projects the timestep embedding to the 6 modulation vectors
        # (shift/scale/gate for attention and for the MLP).
        self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        # Returns (modulation params of shape (N, 6 * embedding_dim),
        #          raw embedded timestep of shape (N, embedding_dim)).
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
187
+
188
+ class PixArtAlphaTextProjection(nn.Module):
189
+ """
190
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
191
+
192
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
193
+ """
194
+
195
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
196
+ super().__init__()
197
+ if out_features is None:
198
+ out_features = hidden_size
199
+ self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
200
+ if act_fn == "gelu_tanh":
201
+ self.act_1 = nn.GELU(approximate="tanh")
202
+ elif act_fn == "silu":
203
+ self.act_1 = nn.SiLU()
204
+ else:
205
+ raise ValueError(f"Unknown activation function: {act_fn}")
206
+ self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
207
+
208
+ def forward(self, caption):
209
+ hidden_states = self.linear_1(caption)
210
+ hidden_states = self.act_1(hidden_states)
211
+ hidden_states = self.linear_2(hidden_states)
212
+ return hidden_states
213
+
214
+
215
class GELU_approx(nn.Module):
    """Linear projection followed by tanh-approximated GELU."""

    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)

    def forward(self, x):
        projected = self.proj(x)
        return torch.nn.functional.gelu(projected, approximate="tanh")
222
+
223
+
224
class FeedForward(nn.Module):
    """GELU MLP: GELU_approx projection up -> dropout -> linear back down.

    ``glu`` is accepted for interface compatibility but is not used here.
    """

    def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
        super().__init__()
        hidden_dim = int(dim * mult)
        self.net = nn.Sequential(
            GELU_approx(dim, hidden_dim, dtype=dtype, device=device, operations=operations),
            nn.Dropout(dropout),
            operations.Linear(hidden_dim, dim_out, dtype=dtype, device=device),
        )

    def forward(self, x):
        return self.net(x)
238
+
239
+
240
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
    """Apply a rotary position embedding to ``input_tensor``.

    Args:
        input_tensor: Tensor whose last dimension holds interleaved (even, odd)
            feature pairs.
        freqs_cis: (cos, sin) pair broadcastable against ``input_tensor``.

    Returns:
        ``input_tensor * cos + rotate_half(input_tensor) * sin``.

    Note: rewritten with plain torch ops (unflatten/flatten) instead of two
    einops ``rearrange`` calls — equivalent result, one less dependency in a
    hot path (see the TODO above about consolidating these helpers).
    """
    cos_freqs, sin_freqs = freqs_cis

    # Pair the last dim as (..., d, 2), rotate each (x, y) pair to (-y, x),
    # then flatten back to the original layout.
    pairs = input_tensor.unflatten(-1, (-1, 2))
    x_even, x_odd = pairs.unbind(dim=-1)
    rotated = torch.stack((-x_odd, x_even), dim=-1).flatten(-2)

    return input_tensor * cos_freqs + rotated * sin_freqs
252
+
253
+
254
class CrossAttention(nn.Module):
    """Self/cross attention with RMS-normed q/k and optional rotary embeddings.

    If ``context`` is None the layer acts as self-attention; otherwise keys and
    values are computed from ``context`` (cross-attention).
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        # q/k RMSNorm is applied over the full inner_dim (all heads at once),
        # before the heads are split inside the attention kernel.
        self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
        self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None, pe=None):
        """Attend over ``x`` (or ``context`` if given); ``pe`` is an optional (cos, sin) rotary pair."""
        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # Rotary embedding is applied after normalization, to q and k only.
        if pe is not None:
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe)

        if mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        return self.to_out(out)
291
+
292
+
293
class BasicTransformerBlock(nn.Module):
    """LTXV transformer block: modulated self-attention, cross-attention, and a
    modulated feed-forward, each added residually (in place) to ``x``.

    Per-block shift/scale/gate values come from the learned ``scale_shift_table``
    plus the adaLN timestep projection passed in as ``timestep``.
    """

    def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()

        self.attn_precision = attn_precision
        # attn1: self-attention (rotary pe applied); attn2: cross-attention to `context`.
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
        self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)

        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)

        # 6 rows: shift/scale/gate for self-attention, then for the MLP.
        self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
        # `timestep` carries 6*dim modulation values; combine with the learned
        # table and split into the six (shift, scale, gate) vectors.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa

        # Cross-attention is neither modulated nor gated.
        x += self.attn2(x, context=context, mask=attention_mask)

        y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
        x += self.ff(y) * gate_mlp

        return x
316
+
317
def get_fractional_positions(indices_grid, max_pos):
    """Normalize per-axis (t, h, w) indices to fractions of ``max_pos``.

    indices_grid: (batch, 3, n) tensor of coordinates.
    max_pos: sequence of 3 per-axis maxima.
    Returns a (batch, n, 3) tensor of fractional positions.
    """
    axes = [indices_grid[:, axis] / max_pos[axis] for axis in range(3)]
    return torch.stack(axes, dim=-1)
326
+
327
+
328
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
    """Precompute (cos, sin) rotary tables for 3-D (t, h, w) positions.

    Args:
        indices_grid: (batch, 3, n) coordinate grid.
        dim: attention inner dimension; dim // 6 frequencies per axis, two
            components (cos/sin pair) per frequency.
        out_dtype: dtype of the returned tables.
        theta: base of the geometric frequency ladder.
        max_pos: per-axis position maxima used to normalize coordinates.

    Returns:
        (cos, sin) pair of shape (batch, n, dim) in ``out_dtype``.
    """
    dtype = torch.float32 #self.dtype

    fractional_positions = get_fractional_positions(indices_grid, max_pos)

    start = 1
    end = theta
    device = fractional_positions.device

    # Geometric ladder of dim // 6 frequencies between 1 and theta.
    indices = theta ** (
        torch.linspace(
            math.log(start, theta),
            math.log(end, theta),
            dim // 6,
            device=device,
            dtype=dtype,
        )
    )
    indices = indices.to(dtype=dtype)

    indices = indices * math.pi / 2

    # Map fractional positions from [0, 1] to [-1, 1], multiply by the
    # frequencies, then interleave the three axes into one flat vector.
    freqs = (
        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
        .transpose(-1, -2)
        .flatten(2)
    )

    # Duplicate each frequency so cos/sin line up with (even, odd) feature pairs.
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if dim % 6 != 0:
        # Pad leftover channels with the identity rotation (cos=1, sin=0).
        cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
        sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
364
+
365
+
366
class LTXVModel(torch.nn.Module):
    """LTX-Video diffusion transformer.

    Patchifies a video latent, runs it through ``num_layers`` BasicTransformerBlocks
    conditioned on a timestep (adaLN-single) and projected caption embeddings,
    then unpatchifies back to the input latent shape.
    """

    def __init__(self,
                 in_channels=128,
                 cross_attention_dim=2048,
                 attention_head_dim=64,
                 num_attention_heads=32,

                 caption_channels=4096,
                 num_layers=28,


                 positional_embedding_theta=10000.0,
                 positional_embedding_max_pos=[20, 2048, 2048],
                 causal_temporal_positioning=False,
                 vae_scale_factors=(8, 32, 32),
                 dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.generator = None
        self.vae_scale_factors = vae_scale_factors
        self.dtype = dtype
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        self.causal_temporal_positioning = causal_temporal_positioning

        # Projects patchified latents into the transformer width.
        self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
        )

        # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)

        self.caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    context_dim=cross_attention_dim,
                    # attn_precision=attn_precision,
                    dtype=dtype, device=device, operations=operations
                )
                for d in range(num_layers)
            ]
        )

        # Output modulation (shift, scale) table for the final norm.
        self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
        self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)

        self.patchifier = SymmetricPatchifier(1)

    def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
        """Denoise latent video ``x`` (b, c, f, h, w) conditioned on ``timestep`` and text ``context``."""
        patches_replace = transformer_options.get("patches_replace", {})

        orig_shape = list(x.shape)

        # Flatten the latent video to a token sequence + per-token coordinates.
        x, latent_coords = self.patchifier.patchify(x)
        pixel_coords = latent_to_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=self.vae_scale_factors,
            causal_fix=self.causal_temporal_positioning,
        )

        # Optionally override the trailing token coordinates with keyframe positions.
        if keyframe_idxs is not None:
            pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs

        # Temporal axis is expressed in seconds (frames / frame_rate).
        fractional_coords = pixel_coords.to(torch.float32)
        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)

        x = self.patchify_proj(x)
        timestep = timestep * 1000.0

        # Convert a 0/1 mask into an additive attention bias (0 kept, -max masked).
        if attention_mask is not None and not torch.is_floating_point(attention_mask):
            attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max

        pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)

        batch_size = x.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=x.dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(
            batch_size, -1, embedded_timestep.shape[-1]
        )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = x.shape[0]
            context = self.caption_projection(context)
            context = context.view(
                batch_size, -1, x.shape[-1]
            )

        # Allow external patches (e.g. control adapters) to wrap each block.
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(
                    x,
                    context=context,
                    attention_mask=attention_mask,
                    timestep=timestep,
                    pe=pe
                )

        # 3. Output
        scale_shift_values = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        x = self.norm_out(x)
        # Modulation
        x = x * (1 + scale) + shift
        x = self.proj_out(x)

        # Restore the original (b, c, f, h, w) latent layout.
        x = self.patchifier.unpatchify(
            latents=x,
            output_height=orig_shape[3],
            output_width=orig_shape[4],
            output_num_frames=orig_shape[2],
            out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
        )

        return x
ComfyUI/comfy/ldm/lightricks/symmetric_patchifier.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor
7
+
8
+
9
def latent_to_pixel_coords(
    latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
) -> Tensor:
    """
    Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
    configuration.
    Args:
        latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
            containing the latent corner coordinates of each token.
        scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
        causal_fix (bool): Whether to take into account the different temporal scale
            of the first frame. Default = False for backwards compatibility.
    Returns:
        Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
    """
    factors = torch.tensor(scale_factors, device=latent_coords.device)
    pixel_coords = latent_coords * factors[None, :, None]
    if causal_fix:
        # The causal VAE maps the first latent frame to a single pixel frame,
        # so shift the temporal axis back and clamp at zero.
        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords
32
+
33
+
34
class Patchifier(ABC):
    """Base class for converting latent videos to/from flat token sequences."""

    def __init__(self, patch_size: int):
        super().__init__()
        # Temporal patch size is fixed to 1; only the spatial dims are patched.
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(
        self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
    ) -> Tuple[Tensor, Tensor]:
        # NOTE(review): concrete implementations (SymmetricPatchifier) accept
        # only `latents`; this abstract signature looks out of date — confirm.
        pass

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tensor:
        pass

    @property
    def patch_size(self):
        # (temporal, height, width) patch sizes.
        return self._patch_size

    def get_latent_coords(
        self, latent_num_frames, latent_height, latent_width, batch_size, device
    ):
        """
        Return a tensor of shape [batch_size, 3, num_patches] containing the
        top-left corner latent coordinates of each latent patch.
        The tensor is repeated for each batch element.
        """
        latent_sample_coords = torch.meshgrid(
            torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
            torch.arange(0, latent_height, self._patch_size[1], device=device),
            torch.arange(0, latent_width, self._patch_size[2], device=device),
            indexing="ij",
        )
        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
        latent_coords = rearrange(
            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
        )
        return latent_coords
80
+
81
+
82
class SymmetricPatchifier(Patchifier):
    """Patchifier with equal patch size in height and width (temporal size 1)."""

    def patchify(
        self,
        latents: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Flatten (b, c, f, h, w) latents into (b, tokens, channels) plus per-token coords."""
        b, _, f, h, w = latents.shape
        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
        latents = rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents, latent_coords

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tensor:
        """Inverse of patchify: token sequence back to (b, c, f, h, w).

        ``out_channels`` is accepted for interface compatibility but unused;
        the channel count is inferred by the rearrange pattern.
        """
        # Convert pixel-space height/width to token-grid dimensions.
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q) ",
            f=output_num_frames,
            h=output_height,
            w=output_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents
ComfyUI/comfy/ldm/lightricks/vae/causal_conv3d.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import comfy.ops
6
+ ops = comfy.ops.disable_weight_init
7
+
8
+
9
class CausalConv3d(nn.Module):
    """3-D convolution that is causal in time.

    Spatial dims get symmetric "same" padding inside the conv; the temporal dim
    is padded manually in forward() by repeating edge frames, so with
    ``causal=True`` no future frames influence the output.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Same kernel extent on every axis; remember the temporal size for padding.
        kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        # Dilation applies to the temporal axis only.
        dilation = (dilation, 1, 1)

        # "Same" padding spatially; zero temporal padding (handled in forward()).
        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

        self.conv = ops.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        # x: (b, c, t, h, w). Pad time by repeating the first frame (and, in the
        # non-causal case, the last frame too) so the output keeps the input length.
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x

    @property
    def weight(self):
        # Expose the underlying conv weight for external weight handling.
        return self.conv.weight
ComfyUI/comfy/ldm/lumina/model.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
2
+ from __future__ import annotations
3
+
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import comfy.ldm.common_dit
10
+
11
+ from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
12
+ from comfy.ldm.modules.attention import optimized_attention_masked
13
+ from comfy.ldm.flux.layers import EmbedND
14
+
15
+
16
+ def modulate(x, scale):
17
+ return x * (1 + scale.unsqueeze(1))
18
+
19
+ #############################################################################
20
+ # Core NextDiT Model #
21
+ #############################################################################
22
+
23
+
24
+ class JointAttention(nn.Module):
25
+ """Multi-head attention module."""
26
+
27
+ def __init__(
28
+ self,
29
+ dim: int,
30
+ n_heads: int,
31
+ n_kv_heads: Optional[int],
32
+ qk_norm: bool,
33
+ operation_settings={},
34
+ ):
35
+ """
36
+ Initialize the Attention module.
37
+
38
+ Args:
39
+ dim (int): Number of input dimensions.
40
+ n_heads (int): Number of heads.
41
+ n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
42
+
43
+ """
44
+ super().__init__()
45
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
46
+ self.n_local_heads = n_heads
47
+ self.n_local_kv_heads = self.n_kv_heads
48
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
49
+ self.head_dim = dim // n_heads
50
+
51
+ self.qkv = operation_settings.get("operations").Linear(
52
+ dim,
53
+ (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
54
+ bias=False,
55
+ device=operation_settings.get("device"),
56
+ dtype=operation_settings.get("dtype"),
57
+ )
58
+ self.out = operation_settings.get("operations").Linear(
59
+ n_heads * self.head_dim,
60
+ dim,
61
+ bias=False,
62
+ device=operation_settings.get("device"),
63
+ dtype=operation_settings.get("dtype"),
64
+ )
65
+
66
+ if qk_norm:
67
+ self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
68
+ self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
69
+ else:
70
+ self.q_norm = self.k_norm = nn.Identity()
71
+
72
+ @staticmethod
73
+ def apply_rotary_emb(
74
+ x_in: torch.Tensor,
75
+ freqs_cis: torch.Tensor,
76
+ ) -> torch.Tensor:
77
+ """
78
+ Apply rotary embeddings to input tensors using the given frequency
79
+ tensor.
80
+
81
+ This function applies rotary embeddings to the given query 'xq' and
82
+ key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
83
+ input tensors are reshaped as complex numbers, and the frequency tensor
84
+ is reshaped for broadcasting compatibility. The resulting tensors
85
+ contain rotary embeddings and are returned as real tensors.
86
+
87
+ Args:
88
+ x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
89
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
90
+ exponentials.
91
+
92
+ Returns:
93
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
94
+ and key tensor with rotary embeddings.
95
+ """
96
+
97
+ t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
98
+ t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
99
+ return t_out.reshape(*x_in.shape)
100
+
101
+ def forward(
102
+ self,
103
+ x: torch.Tensor,
104
+ x_mask: torch.Tensor,
105
+ freqs_cis: torch.Tensor,
106
+ ) -> torch.Tensor:
107
+ """
108
+
109
+ Args:
110
+ x:
111
+ x_mask:
112
+ freqs_cis:
113
+
114
+ Returns:
115
+
116
+ """
117
+ bsz, seqlen, _ = x.shape
118
+
119
+ xq, xk, xv = torch.split(
120
+ self.qkv(x),
121
+ [
122
+ self.n_local_heads * self.head_dim,
123
+ self.n_local_kv_heads * self.head_dim,
124
+ self.n_local_kv_heads * self.head_dim,
125
+ ],
126
+ dim=-1,
127
+ )
128
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
129
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
130
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
131
+
132
+ xq = self.q_norm(xq)
133
+ xk = self.k_norm(xk)
134
+
135
+ xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
136
+ xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
137
+
138
+ n_rep = self.n_local_heads // self.n_local_kv_heads
139
+ if n_rep >= 1:
140
+ xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
141
+ xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
142
+ output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
143
+
144
+ return self.out(output)
145
+
146
+
147
+ class FeedForward(nn.Module):
148
+ def __init__(
149
+ self,
150
+ dim: int,
151
+ hidden_dim: int,
152
+ multiple_of: int,
153
+ ffn_dim_multiplier: Optional[float],
154
+ operation_settings={},
155
+ ):
156
+ """
157
+ Initialize the FeedForward module.
158
+
159
+ Args:
160
+ dim (int): Input dimension.
161
+ hidden_dim (int): Hidden dimension of the feedforward layer.
162
+ multiple_of (int): Value to ensure hidden dimension is a multiple
163
+ of this value.
164
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden
165
+ dimension. Defaults to None.
166
+
167
+ """
168
+ super().__init__()
169
+ # custom dim factor multiplier
170
+ if ffn_dim_multiplier is not None:
171
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
172
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
173
+
174
+ self.w1 = operation_settings.get("operations").Linear(
175
+ dim,
176
+ hidden_dim,
177
+ bias=False,
178
+ device=operation_settings.get("device"),
179
+ dtype=operation_settings.get("dtype"),
180
+ )
181
+ self.w2 = operation_settings.get("operations").Linear(
182
+ hidden_dim,
183
+ dim,
184
+ bias=False,
185
+ device=operation_settings.get("device"),
186
+ dtype=operation_settings.get("dtype"),
187
+ )
188
+ self.w3 = operation_settings.get("operations").Linear(
189
+ dim,
190
+ hidden_dim,
191
+ bias=False,
192
+ device=operation_settings.get("device"),
193
+ dtype=operation_settings.get("dtype"),
194
+ )
195
+
196
+ # @torch.compile
197
+ def _forward_silu_gating(self, x1, x3):
198
+ return F.silu(x1) * x3
199
+
200
+ def forward(self, x):
201
+ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
202
+
203
+
204
+ class JointTransformerBlock(nn.Module):
205
+ def __init__(
206
+ self,
207
+ layer_id: int,
208
+ dim: int,
209
+ n_heads: int,
210
+ n_kv_heads: int,
211
+ multiple_of: int,
212
+ ffn_dim_multiplier: float,
213
+ norm_eps: float,
214
+ qk_norm: bool,
215
+ modulation=True,
216
+ operation_settings={},
217
+ ) -> None:
218
+ """
219
+ Initialize a TransformerBlock.
220
+
221
+ Args:
222
+ layer_id (int): Identifier for the layer.
223
+ dim (int): Embedding dimension of the input features.
224
+ n_heads (int): Number of attention heads.
225
+ n_kv_heads (Optional[int]): Number of attention heads in key and
226
+ value features (if using GQA), or set to None for the same as
227
+ query.
228
+ multiple_of (int):
229
+ ffn_dim_multiplier (float):
230
+ norm_eps (float):
231
+
232
+ """
233
+ super().__init__()
234
+ self.dim = dim
235
+ self.head_dim = dim // n_heads
236
+ self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
237
+ self.feed_forward = FeedForward(
238
+ dim=dim,
239
+ hidden_dim=4 * dim,
240
+ multiple_of=multiple_of,
241
+ ffn_dim_multiplier=ffn_dim_multiplier,
242
+ operation_settings=operation_settings,
243
+ )
244
+ self.layer_id = layer_id
245
+ self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
246
+ self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
247
+
248
+ self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
249
+ self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
250
+
251
+ self.modulation = modulation
252
+ if modulation:
253
+ self.adaLN_modulation = nn.Sequential(
254
+ nn.SiLU(),
255
+ operation_settings.get("operations").Linear(
256
+ min(dim, 1024),
257
+ 4 * dim,
258
+ bias=True,
259
+ device=operation_settings.get("device"),
260
+ dtype=operation_settings.get("dtype"),
261
+ ),
262
+ )
263
+
264
+ def forward(
265
+ self,
266
+ x: torch.Tensor,
267
+ x_mask: torch.Tensor,
268
+ freqs_cis: torch.Tensor,
269
+ adaln_input: Optional[torch.Tensor]=None,
270
+ ):
271
+ """
272
+ Perform a forward pass through the TransformerBlock.
273
+
274
+ Args:
275
+ x (torch.Tensor): Input tensor.
276
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
277
+
278
+ Returns:
279
+ torch.Tensor: Output tensor after applying attention and
280
+ feedforward layers.
281
+
282
+ """
283
+ if self.modulation:
284
+ assert adaln_input is not None
285
+ scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
286
+
287
+ x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
288
+ self.attention(
289
+ modulate(self.attention_norm1(x), scale_msa),
290
+ x_mask,
291
+ freqs_cis,
292
+ )
293
+ )
294
+ x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
295
+ self.feed_forward(
296
+ modulate(self.ffn_norm1(x), scale_mlp),
297
+ )
298
+ )
299
+ else:
300
+ assert adaln_input is None
301
+ x = x + self.attention_norm2(
302
+ self.attention(
303
+ self.attention_norm1(x),
304
+ x_mask,
305
+ freqs_cis,
306
+ )
307
+ )
308
+ x = x + self.ffn_norm2(
309
+ self.feed_forward(
310
+ self.ffn_norm1(x),
311
+ )
312
+ )
313
+ return x
314
+
315
+
316
class FinalLayer(nn.Module):
    """
    The final layer of NextDiT: a non-affine LayerNorm modulated by the
    conditioning vector, followed by a linear projection to per-patch pixels.
    """

    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
        super().__init__()
        ops = operation_settings.get("operations")
        dev = operation_settings.get("device")
        dt = operation_settings.get("dtype")

        # Normalization carries no learned affine; scaling comes from adaLN.
        self.norm_final = ops.LayerNorm(
            hidden_size,
            elementwise_affine=False,
            eps=1e-6,
            device=dev,
            dtype=dt,
        )
        # Projects each token to a flattened patch of output channels.
        self.linear = ops.Linear(
            hidden_size,
            patch_size * patch_size * out_channels,
            bias=True,
            device=dev,
            dtype=dt,
        )

        # Conditioning vector (width min(hidden_size, 1024)) -> per-channel scale.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            ops.Linear(
                min(hidden_size, 1024),
                hidden_size,
                bias=True,
                device=dev,
                dtype=dt,
            ),
        )

    def forward(self, x, c):
        """Scale-modulate the normalized tokens by conditioning `c`, then project."""
        modulation_scale = self.adaLN_modulation(c)
        return self.linear(modulate(self.norm_final(x), modulation_scale))
354
+
355
+
356
class NextDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Pipeline: patchify image latents -> refine noise tokens (modulated blocks)
    and caption tokens (unmodulated blocks) separately -> concatenate
    [caption, image] per sample -> main transformer stack -> final layer ->
    unpatchify. Positional information comes from 3-axis RoPE
    (caption index, row, column).
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_refiner_layers: int = 2,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (1, 512, 512),
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size

        # Projects a flattened patch (p*p*C values) to the model width.
        self.x_embedder = operation_settings.get("operations").Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        # Timestep-modulated refiner blocks applied to image tokens only.
        self.noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        # Unmodulated refiner blocks applied to caption tokens only.
        self.context_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
        # Caption features: RMSNorm then a linear projection to the model width.
        self.cap_embedder = nn.Sequential(
            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
            operation_settings.get("operations").Linear(
                cap_feat_dim,
                dim,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

        # Main transformer stack over the joint [caption, image] sequence.
        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)

        # RoPE axes must tile the per-head dimension exactly.
        assert (dim // n_heads) == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
        self.dim = dim
        self.n_heads = n_heads

    def unpatchify(
        self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
    ) -> List[torch.Tensor]:
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)

        Per sample, skip the first cap_size[i] caption tokens, then fold the
        following (H/p)*(W/p) image tokens back into a (C, H, W) image.
        When return_tensor is True the per-sample images are stacked (all
        samples must then share the same H, W).
        """
        pH = pW = self.patch_size
        imgs = []
        for i in range(x.size(0)):
            H, W = img_size[i]
            begin = cap_size[i]
            end = begin + (H // pH) * (W // pW)
            imgs.append(
                x[i][begin:end]
                .view(H // pH, W // pW, pH, pW, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )

        if return_tensor:
            imgs = torch.stack(imgs, dim=0)
        return imgs

    def patchify_and_embed(
        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
        """Patchify images, refine caption/image tokens, and build the padded
        joint sequence.

        Returns (padded_full_embed, mask, img_sizes, l_effective_cap_len,
        freqs_cis): the (N, max_seq_len, dim) token tensor, an additive
        attention mask (or None), per-sample image sizes, per-sample caption
        lengths, and the RoPE table for the joint sequence.
        """
        bsz = len(x)
        pH = pW = self.patch_size
        device = x[0].device
        dtype = x[0].dtype

        # Effective caption length per sample: mask sum if given, else fixed.
        if cap_mask is not None:
            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
        else:
            l_effective_cap_len = [num_tokens] * bsz

        # Convert a boolean/integer mask to an additive float mask:
        # kept positions (1) -> 0, masked positions (0) -> -finfo.max.
        if cap_mask is not None and not torch.is_floating_point(cap_mask):
            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max

        img_sizes = [(img.size(1), img.size(2)) for img in x]
        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]

        max_seq_len = max(
            (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
        )
        max_cap_len = max(l_effective_cap_len)
        max_img_len = max(l_effective_img_len)

        # 3-axis position ids: axis 0 counts caption tokens (image tokens all
        # share position cap_len there); axes 1/2 are the patch row/column.
        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // pH, W // pW
            assert H_tokens * W_tokens == img_len

            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
            position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
            position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)

        # build freqs_cis for cap and image individually
        cap_freqs_cis_shape = list(freqs_cis.shape)
        # cap_freqs_cis_shape[1] = max_cap_len
        cap_freqs_cis_shape[1] = cap_feats.shape[1]
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]

        # refine context
        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

        # refine image
        # Patchify: (C, H, W) -> (num_patches, p*p*C), row-major patch order.
        flat_x = []
        for i in range(bsz):
            img = x[i]
            C, H, W = img.size()
            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
            flat_x.append(img)
        x = flat_x
        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
        # Additive mask: padded image positions get -finfo.max.
        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
        for i in range(bsz):
            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max

        padded_img_embed = self.x_embedder(padded_img_embed)
        padded_img_mask = padded_img_mask.unsqueeze(1)
        for layer in self.noise_refiner:
            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)

        # Joint-sequence mask only covers caption positions; image positions
        # default to 0 (attendable).
        if cap_mask is not None:
            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
        else:
            mask = None

        # Concatenate [caption tokens | image tokens] per sample, zero-padded.
        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]

            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
            padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]

        return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

    # def forward(self, x, t, cap_feats, cap_mask):
    def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
        # ComfyUI convention adapter: timesteps are flipped (t = 1 - timesteps)
        # and the model output is negated before returning.
        t = 1.0 - timesteps
        cap_feats = context
        cap_mask = attention_mask
        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        """
        Forward pass of NextDiT.
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of text tokens/features
        """

        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        adaln_input = t

        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D) # todo check if able to batchify w.o. redundant compute

        x_is_tensor = isinstance(x, torch.Tensor)
        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
        freqs_cis = freqs_cis.to(x.device)

        for layer in self.layers:
            x = layer(x, mask, freqs_cis, adaln_input)

        x = self.final_layer(x, adaln_input)
        # Crop away any pad_to_patch_size padding back to the input h, w.
        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]

        return -x
622
+
ComfyUI/comfy/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import torch
4
+ from contextlib import contextmanager
5
+ from typing import Any, Dict, Tuple, Union
6
+
7
+ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8
+
9
+ from comfy.ldm.util import get_obj_from_str, instantiate_from_config
10
+ from comfy.ldm.modules.ema import LitEma
11
+ import comfy.ops
12
+
13
class DiagonalGaussianRegularizer(torch.nn.Module):
    """KL-style latent regularizer: interprets the encoder output as the
    parameters of a diagonal Gaussian and returns either a sample from it or
    its mode."""

    def __init__(self, sample: bool = False):
        super().__init__()
        # When True, forward() draws a stochastic sample; otherwise it
        # returns the distribution mode (deterministic).
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        # No trainable parameters of its own.
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        posterior = DiagonalGaussianDistribution(z)
        latent = posterior.sample() if self.sample else posterior.mode()
        # NOTE: returns None in place of a regularization log dict.
        return latent, None
28
+
29
+
30
class AbstractAutoencoder(torch.nn.Module):
    """
    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
    unCLIP models, etc. Hence, it is fairly general, and specific features
    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
        **kwargs,
    ):
        super().__init__()

        self.input_key = input_key
        self.use_ema = ema_decay is not None
        # `monitor` is only attached when explicitly provided.
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if not self.use_ema:
            return
        self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in EMA weights; restores originals on exit."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                logging.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    logging.info(f"{context}: Restored training weights")

    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        opt_cls = get_obj_from_str(cfg["target"])
        return opt_cls(params, lr=lr, **cfg.get("params", dict()))

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()
92
+
93
+
94
class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        regularizer_config: Dict,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Encoder, decoder and regularizer are all built from config dicts.
        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
        self.regularization = instantiate_from_config(regularizer_config)

    def get_last_layer(self):
        return self.decoder.get_last_layer()

    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Encode `x`; optionally skip regularization or return its log."""
        z = self.encoder(x)
        if unregularized:
            # NOTE: always returns a (z, log) pair in this branch.
            return z, dict()
        z, reg_log = self.regularization(z)
        return (z, reg_log) if return_reg_log else z

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.decoder(z, **kwargs)

    def forward(
        self, x: torch.Tensor, **additional_decode_kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Full round trip: encode, then decode; returns (z, dec, reg_log)."""
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log
144
+
145
+
146
class AutoencodingEngineLegacy(AutoencodingEngine):
    """Legacy AutoencodingEngine wiring an `ddconfig`-style Encoder/Decoder
    pair plus quant/post-quant 1x1 convs, with optional batched encode/decode
    via `max_batch_size`."""

    def __init__(self, embed_dim: int, **kwargs):
        # Popped before super() so they are not forwarded to the parent.
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        ddconfig = kwargs.pop("ddconfig")
        super().__init__(
            encoder_config={
                "target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
                "params": ddconfig,
            },
            decoder_config={
                "target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
                "params": ddconfig,
            },
            **kwargs,
        )

        # Weight-init is disabled; weights are expected to be loaded later.
        if ddconfig.get("conv3d", False):
            conv_op = comfy.ops.disable_weight_init.Conv3d
        else:
            conv_op = comfy.ops.disable_weight_init.Conv2d

        # 1x1 conv mapping encoder output (doubled channels when double_z,
        # for mean+logvar) to the embedding space.
        self.quant_conv = conv_op(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
            (1 + ddconfig["double_z"]) * embed_dim,
            1,
        )

        self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

    def get_autoencoder_params(self) -> list:
        # NOTE(review): no ancestor in this file defines
        # get_autoencoder_params(); this super() call would raise
        # AttributeError unless a mixin supplies it — confirm against callers.
        params = super().get_autoencoder_params()
        return params

    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Encode `x` to the regularized latent; optionally return the reg log.

        When max_batch_size is set, the encoder + quant_conv are run in
        chunks of that size to bound peak memory.
        """
        if self.max_batch_size is None:
            z = self.encoder(x)
            z = self.quant_conv(z)
        else:
            N = x.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            z = list()
            for i_batch in range(n_batches):
                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
                z_batch = self.quant_conv(z_batch)
                z.append(z_batch)
            z = torch.cat(z, 0)

        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        """Decode latent `z` back to image space, chunked when
        max_batch_size is set (mirrors encode())."""
        if self.max_batch_size is None:
            dec = self.post_quant_conv(z)
            dec = self.decoder(dec, **decoder_kwargs)
        else:
            N = z.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            dec = list()
            for i_batch in range(n_batches):
                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
                dec.append(dec_batch)
            dec = torch.cat(dec, 0)

        return dec
218
+
219
+
220
class AutoencoderKL(AutoencodingEngineLegacy):
    """Legacy KL autoencoder: AutoencodingEngineLegacy preconfigured with a
    diagonal-Gaussian (KL) regularizer."""

    def __init__(self, **kwargs):
        # Accept the historical "lossconfig" spelling and normalize it.
        legacy_loss = kwargs.pop("lossconfig", None)
        if legacy_loss is not None:
            kwargs["loss_config"] = legacy_loss
        regularizer = {
            "target": (
                "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
            )
        }
        super().__init__(regularizer_config=regularizer, **kwargs)
ComfyUI/comfy/ldm/modules/attention.py ADDED
@@ -0,0 +1,1035 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import sys
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn, einsum
7
+ from einops import rearrange, repeat
8
+ from typing import Optional
9
+ import logging
10
+
11
+ from .diffusionmodules.util import AlphaBlender, timestep_embedding
12
+ from .sub_quadratic_attention import efficient_dot_product_attention
13
+
14
+ from comfy import model_management
15
+
16
+ if model_management.xformers_enabled():
17
+ import xformers
18
+ import xformers.ops
19
+
20
+ if model_management.sage_attention_enabled():
21
+ try:
22
+ from sageattention import sageattn
23
+ except ModuleNotFoundError as e:
24
+ if e.name == "sageattention":
25
+ logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
26
+ else:
27
+ raise e
28
+ exit(-1)
29
+
30
+ if model_management.flash_attention_enabled():
31
+ try:
32
+ from flash_attn import flash_attn_func
33
+ except ModuleNotFoundError:
34
+ logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
35
+ exit(-1)
36
+
37
+ from comfy.cli_args import args
38
+ import comfy.ops
39
+ ops = comfy.ops.disable_weight_init
40
+
41
+ FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
42
+
43
def get_attn_precision(attn_precision, current_dtype):
    """Resolve the dtype to run attention math in.

    Returns None when upcasting is globally disabled, a forced dtype when the
    current dtype is in the force-upcast table, else the requested precision.
    """
    if args.dont_upcast_attention:
        return None

    forced = FORCE_UPCAST_ATTENTION_DTYPE
    if forced is not None and current_dtype in forced:
        return forced[current_dtype]
    return attn_precision
50
+
51
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
53
+
54
+
55
def default(val, d):
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
59
+
60
+
61
+ # feedforward
62
class GEGLU(nn.Module):
    """GELU-gated linear unit: projects to twice the output width, then
    multiplies one half by GELU of the other (https://arxiv.org/abs/2002.05202
    style gating)."""

    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
        super().__init__()
        # Single projection producing both the value and the gate halves.
        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
70
+
71
+
72
class FeedForward(nn.Module):
    """Transformer MLP block: Linear+GELU (or GEGLU when glu=True), dropout,
    then a linear projection back to `dim_out` (defaults to `dim`)."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = int(dim * mult)
        out_dim = default(dim_out, dim)

        # Input stage: gated (GEGLU) or plain Linear+GELU.
        if glu:
            project_in = GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)
        else:
            project_in = nn.Sequential(
                operations.Linear(dim, inner_dim, dtype=dtype, device=device),
                nn.GELU(),
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(inner_dim, out_dim, dtype=dtype, device=device),
        )

    def forward(self, x):
        return self.net(x)
90
+
91
def Normalize(in_channels, dtype=None, device=None):
    """Build the standard 32-group GroupNorm (eps=1e-6, affine) used by the
    attention/conv blocks in this module."""
    return torch.nn.GroupNorm(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
        dtype=dtype,
        device=device,
    )
93
+
94
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Reference multi-head attention via explicit einsum + softmax.

    q/k/v: (B, T, heads*dim_head) by default, or (B, heads, T, dim_head) when
    skip_reshape is True. mask may be boolean (True = keep) or an additive
    float mask broadcastable over (B*heads, Tq, Tk). Output is
    (B, T, heads*dim_head), or (B, heads, T, dim_head) when
    skip_output_reshape is True.
    """
    attn_precision = get_attn_precision(attn_precision, q.dtype)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    h = heads
    # Flatten batch and heads together: (B*heads, T, dim_head).
    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    # force cast to fp32 to avoid overflowing
    if attn_precision == torch.float32:
        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if exists(mask):
        if mask.dtype == torch.bool:
            mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)
        else:
            # Additive float mask: broadcast to (B*heads, Tq, Tk) then add.
            if len(mask.shape) == 2:
                bs = 1
            else:
                bs = mask.shape[0]
            mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
            sim.add_(mask)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)

    if skip_output_reshape:
        out = (
            out.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
        )
    else:
        out = (
            out.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, heads * dim_head)
        )
    return out
161
+
162
+
163
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Memory-efficient attention using chunked (sub-quadratic) evaluation.

    Chooses a query chunk size based on currently free device memory, then
    delegates to efficient_dot_product_attention. Same tensor layout
    conventions as attention_basic; mask (if given) must be an additive
    float mask.
    """
    attn_precision = get_attn_precision(attn_precision, query.dtype)

    if skip_reshape:
        b, _, _, dim_head = query.shape
    else:
        b, _, dim_head = query.shape
        dim_head //= heads

    # Note: key ends up transposed to (B*heads, dim_head, Tk).
    if skip_reshape:
        query = query.reshape(b * heads, -1, dim_head)
        value = value.reshape(b * heads, -1, dim_head)
        key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
    else:
        query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)


    dtype = query.dtype
    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
    if upcast_attention:
        bytes_per_token = torch.finfo(torch.float32).bits//8
    else:
        bytes_per_token = torch.finfo(query.dtype).bits//8
    batch_x_heads, q_tokens, _ = query.shape
    _, _, k_tokens = key.shape

    mem_free_total, _ = model_management.get_free_memory(query.device, True)

    kv_chunk_size_min = None
    kv_chunk_size = None
    query_chunk_size = None

    # Pick the largest query chunk whose attention matrix fits in free memory
    # (4x headroom factor); if none fits with full kv, fall back to 512.
    for x in [4096, 2048, 1024, 512, 256]:
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens
            query_chunk_size = x
            break

    if query_chunk_size is None:
        query_chunk_size = 512

    if mask is not None:
        # Broadcast the additive mask to (B*heads, Tq, Tk).
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    hidden_states = efficient_dot_product_attention(
        query,
        key,
        value,
        query_chunk_size=query_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=kv_chunk_size_min,
        use_checkpoint=False,
        upcast_attention=upcast_attention,
        mask=mask,
    )

    hidden_states = hidden_states.to(dtype)
    if skip_output_reshape:
        hidden_states = hidden_states.unflatten(0, (-1, heads))
    else:
        hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
    return hidden_states
232
+
233
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Attention computed in query-row slices to bound peak memory, with an
    OOM retry loop that empties the cache and doubles the slice count.

    Same tensor layout conventions as attention_basic; mask (if given) must
    be an additive float mask.
    """
    attn_precision = get_attn_precision(attn_precision, q.dtype)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    # Output accumulator, filled slice by slice.
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    mem_free_total = model_management.get_free_memory(q.device)

    if attn_precision == torch.float32:
        element_size = 4
        upcast = True
    else:
        element_size = q.element_size()
        upcast = False

    gb = 1024 ** 3
    # Estimated size of one full attention matrix, times a 3x safety factor.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    modifier = 3
    mem_required = tensor_size * modifier
    steps = 1


    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

    if mask is not None:
        # Broadcast the additive mask to (B*heads, Tq, Tk).
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
    first_op_done = False
    cleared_cache = False
    while True:
        try:
            # If q's length doesn't divide evenly, fall back to one full slice.
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                if upcast:
                    with torch.autocast(enabled=False, device_type = 'cuda'):
                        s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                else:
                    s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

                if mask is not None:
                    if len(mask.shape) == 2:
                        s1 += mask[i:end]
                    else:
                        if mask.shape[1] == 1:
                            s1 += mask
                        else:
                            s1 += mask[:, i:end]

                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                first_op_done = True

                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                del s2
            break
        except model_management.OOM_EXCEPTION as e:
            # Only retry if nothing has been written yet (a partial pass would
            # leave r1 inconsistent). First empty the cache, then keep
            # doubling the slice count up to 64 before giving up.
            if first_op_done == False:
                model_management.soft_empty_cache(True)
                if cleared_cache == False:
                    cleared_cache = True
                    logging.warning("out of memory error, emptying cache and trying again")
                    continue
                steps *= 2
                if steps > 64:
                    raise e
                logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
            else:
                raise e

    del q, k, v

    if skip_output_reshape:
        r1 = (
            r1.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
        )
    else:
        r1 = (
            r1.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, heads * dim_head)
        )
    return r1
353
+
354
# Detect xformers builds with a known memory_efficient_attention defect so
# attention_xformers() can fall back to PyTorch SDPA for oversized batches.
BROKEN_XFORMERS = False
try:
    x_vers = xformers.__version__
    # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
    BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
except:
    # xformers may be absent or lack __version__; treat it as not broken.
    pass
361
+
362
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Multi-head attention via xformers.ops.memory_efficient_attention.

    q/k/v are (b, n, heads*dim_head) by default, or (b, heads, n, dim_head)
    when skip_reshape is True. Falls back to attention_pytorch when xformers
    is known-broken for this batch size or when tracing/scripting.
    """
    b = q.shape[0]
    dim_head = q.shape[-1]
    # check to make sure xformers isn't broken
    disabled_xformers = False

    if BROKEN_XFORMERS:
        # affected versions raise a CUDA error when b * heads > 65535
        if b * heads > 65535:
            disabled_xformers = True

    if not disabled_xformers:
        if torch.jit.is_tracing() or torch.jit.is_scripting():
            disabled_xformers = True

    if disabled_xformers:
        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)

    if skip_reshape:
        # b h k d -> b k h d (xformers wants batch, tokens, heads, dim_head)
        q, k, v = map(
            lambda t: t.permute(0, 2, 1, 3),
            (q, k, v),
        )
    # actually do the reshaping
    else:
        dim_head //= heads
        q, k, v = map(
            lambda t: t.reshape(b, -1, heads, dim_head),
            (q, k, v),
        )

    if mask is not None:
        # add a singleton batch dimension
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a singleton heads dimension
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)
        # pad to a multiple of 8
        pad = 8 - mask.shape[-1] % 8
        # the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
        # but when using separated heads, the shape has to be (B, H, Nq, Nk)
        # in flux, this matrix ends up being over 1GB
        # here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
        mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)

        mask_out[..., :mask.shape[-1]] = mask
        # NOTE(review): slicing mask_out back to the original width returns a
        # VIEW into the padded buffer, so the underlying storage stays padded
        # to a multiple of 8 — presumably an xformers alignment requirement;
        # confirm before simplifying this away.
        mask = mask_out[..., :mask.shape[-1]]
        mask = mask.expand(b, heads, -1, -1)

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    if skip_output_reshape:
        # return (b, heads, n, dim_head)
        out = out.permute(0, 2, 1, 3)
    else:
        # flatten heads back into the channel dim: (b, n, heads*dim_head)
        out = (
            out.reshape(b, -1, heads * dim_head)
        )

    return out
423
+
424
# Largest batch (b * heads collapsed into dim 0) handled by a single
# scaled_dot_product_attention call; attention_pytorch chunks above this.
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
    SDP_BATCH_LIMIT = 2**15
else:
    #TODO: other GPUs ?
    SDP_BATCH_LIMIT = 2**31
429
+
430
+
431
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Multi-head attention via torch.nn.functional.scaled_dot_product_attention.

    q/k/v are (b, n, heads*dim_head) by default, or (b, heads, n, dim_head)
    when skip_reshape is True. Batches larger than SDP_BATCH_LIMIT are
    processed in chunks along dim 0.
    """
    if skip_reshape:
        batch, _, _, head_dim = q.shape
    else:
        batch, _, inner = q.shape
        head_dim = inner // heads
        # (b, n, heads*dim_head) -> (b, heads, n, dim_head)
        q = q.view(batch, -1, heads, head_dim).transpose(1, 2)
        k = k.view(batch, -1, heads, head_dim).transpose(1, 2)
        v = v.view(batch, -1, heads, head_dim).transpose(1, 2)

    if mask is not None:
        # broadcast a 2-D (Nq, Nk) or 3-D (B, Nq, Nk) mask up to (B, H, Nq, Nk)
        while mask.ndim < 4:
            mask = mask.unsqueeze(0 if mask.ndim == 2 else 1)

    if batch <= SDP_BATCH_LIMIT:
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
        if not skip_output_reshape:
            out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
        return out

    # Batch too large for one SDPA call: run it SDP_BATCH_LIMIT rows at a time.
    # NOTE: this path always produces the flattened (B, Nq, H*D) layout,
    # regardless of skip_output_reshape (matches the single-call default).
    out = torch.empty((batch, q.shape[2], heads * head_dim), dtype=q.dtype, layout=q.layout, device=q.device)
    for start in range(0, batch, SDP_BATCH_LIMIT):
        stop = start + SDP_BATCH_LIMIT
        chunk_mask = mask
        if mask is not None and mask.shape[0] > 1:
            # per-sample masks must be sliced along with q/k/v
            chunk_mask = mask[start:stop]

        out[start:stop] = torch.nn.functional.scaled_dot_product_attention(
            q[start:stop],
            k[start:stop],
            v[start:stop],
            attn_mask=chunk_mask,
            dropout_p=0.0, is_causal=False,
        ).transpose(1, 2).reshape(-1, q.shape[2], heads * head_dim)
    return out
472
+
473
+
474
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Multi-head attention via sageattn, with a PyTorch-SDPA fallback.

    Layout passed to sageattn depends on skip_reshape: "HND" keeps the
    incoming (b, heads, n, dim_head); "NHD" reshapes flat inputs to
    (b, n, heads, dim_head).
    """
    if skip_reshape:
        b, _, _, dim_head = q.shape
        tensor_layout = "HND"
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head),
            (q, k, v),
        )
        tensor_layout = "NHD"

    if mask is not None:
        # add a batch dimension if there isn't already one
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a heads dimension if there isn't already one
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    try:
        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
    except Exception as e:
        logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
        # attention_pytorch expects (b, heads, n, dim_head), so NHD inputs
        # need a transpose before handing off with skip_reshape=True.
        if tensor_layout == "NHD":
            q, k, v = map(
                lambda t: t.transpose(1, 2),
                (q, k, v),
            )
        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)

    # sageattn's output layout mirrors its input layout; normalize it.
    if tensor_layout == "HND":
        if not skip_output_reshape:
            out = (
                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
            )
    else:
        if skip_output_reshape:
            out = out.transpose(1, 2)
        else:
            out = out.reshape(b, -1, heads * dim_head)
    return out
517
+
518
+
519
# Register flash_attn as a torch custom op so it composes with torch.compile;
# the fake (meta) kernel only reports the output shape. If torch.library is
# too old (AttributeError), install a stub that raises on use.
try:
    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)


    @flash_attn_wrapper.register_fake
    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
        # Output shape is the same as q
        return q.new_empty(q.shape)
except AttributeError as error:
    FLASH_ATTN_ERROR = error

    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
536
+
537
+
538
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Multi-head attention via the flash_attn custom op, falling back to
    torch SDPA on any failure.

    Note: flash_attn is only attempted for mask is None (enforced by the
    assert inside the try); a non-None mask always takes the SDPA fallback.
    """
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        # (b, n, heads*dim_head) -> (b, heads, n, dim_head)
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
            (q, k, v),
        )

    if mask is not None:
        # add a batch dimension if there isn't already one
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a heads dimension if there isn't already one
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    try:
        assert mask is None
        # flash_attn expects (b, n, heads, dim_head); transpose in and out
        out = flash_attn_wrapper(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            dropout_p=0.0,
            causal=False,
        ).transpose(1, 2)
    except Exception as e:
        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    if not skip_output_reshape:
        out = (
            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
        )
    return out
574
+
575
+
576
# Choose the attention backend once at import time, based on what
# model_management reports as enabled; attention_basic is the fallback.
optimized_attention = attention_basic

if model_management.sage_attention_enabled():
    logging.info("Using sage attention")
    optimized_attention = attention_sage
elif model_management.xformers_enabled():
    logging.info("Using xformers attention")
    optimized_attention = attention_xformers
elif model_management.flash_attention_enabled():
    logging.info("Using Flash Attention")
    optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled():
    logging.info("Using pytorch attention")
    optimized_attention = attention_pytorch
else:
    # no accelerated backend: pick between the two memory-saving fallbacks
    if args.use_split_cross_attention:
        logging.info("Using split optimization for attention")
        optimized_attention = attention_split
    else:
        logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
        optimized_attention = attention_sub_quad

# currently the same implementation handles the masked case too
optimized_attention_masked = optimized_attention
599
+
600
def optimized_attention_for_device(device, mask=False, small_input=False):
    """Return the attention implementation best suited to *device* and input size."""
    if small_input:
        # SDPA tends to win on small inputs when available; otherwise the
        # plain math implementation is fine.
        return attention_pytorch if model_management.pytorch_attention_enabled() else attention_basic

    if device == torch.device("cpu"):
        return attention_sub_quad

    return optimized_attention_masked if mask else optimized_attention
614
+
615
+
616
class CrossAttention(nn.Module):
    """Standard cross-attention: queries from x, keys/values from context
    (self-attention when context is None). Uses the module-level
    optimized_attention backend.
    """
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = dim_head * heads
        # self-attention when no context dim is given
        context_dim = default(context_dim, query_dim)
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, value=None, mask=None):
        """Attend from x over context (defaults to x); an explicit *value*
        tensor overrides the source for the value projection."""
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        if value is not None:
            v = self.to_v(value)
            del value
        else:
            v = self.to_v(context)

        if mask is None:
            out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        return self.to_out(out)
647
+
648
+
649
class BasicTransformerBlock(nn.Module):
    """Transformer block: (optional ff_in) -> self-attn -> cross-attn -> feed-forward,
    each with a pre-LayerNorm and residual connection.

    Supports ComfyUI patch hooks supplied via transformer_options:
    "patches" (attn1_patch / attn2_patch / attn1_output_patch /
    attn2_output_patch / middle_patch) and "patches_replace" (attn1 / attn2),
    which replace the attention computation between the qkv projections and
    the output projection.
    """
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
                 disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()

        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        # residual connections are only valid when inner width matches dim
        self.is_res = inner_dim == dim
        self.attn_precision = attn_precision

        if self.ff_in:
            self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
            self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                    context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

        if disable_temporal_crossattention:
            if switch_temporal_ca_to_sa:
                # the two flags are mutually exclusive
                raise ValueError
            else:
                self.attn2 = None
        else:
            context_dim_attn2 = None
            if not switch_temporal_ca_to_sa:
                context_dim_attn2 = context_dim

            self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
                                        heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
            self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.n_heads = n_heads
        self.d_head = d_head
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

    def forward(self, x, context=None, transformer_options={}):
        # split transformer_options into patch dicts and pass-through extras
        extra_options = {}
        block = transformer_options.get("block", None)
        block_index = transformer_options.get("block_index", 0)
        transformer_patches = {}
        transformer_patches_replace = {}

        for k in transformer_options:
            if k == "patches":
                transformer_patches = transformer_options[k]
            elif k == "patches_replace":
                transformer_patches_replace = transformer_options[k]
            else:
                extra_options[k] = transformer_options[k]

        extra_options["n_heads"] = self.n_heads
        extra_options["dim_head"] = self.d_head
        extra_options["attn_precision"] = self.attn_precision

        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        # ---- attn1 (self-attention unless disable_self_attn) ----
        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context
        else:
            context_attn1 = None
        value_attn1 = None

        if "attn1_patch" in transformer_patches:
            patch = transformer_patches["attn1_patch"]
            if context_attn1 is None:
                context_attn1 = n
            value_attn1 = context_attn1
            for p in patch:
                n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

        # replace-patches can be keyed by (block_type, block_id, block_index)
        # or just by (block_type, block_id)
        if block is not None:
            transformer_block = (block[0], block[1], block_index)
        else:
            transformer_block = None
        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block

        if block_attn1 in attn1_replace_patch:
            if context_attn1 is None:
                context_attn1 = n
                value_attn1 = n
            # run projections here; the patch replaces only the attention core
            n = self.attn1.to_q(n)
            context_attn1 = self.attn1.to_k(context_attn1)
            value_attn1 = self.attn1.to_v(value_attn1)
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
        else:
            n = self.attn1(n, context=context_attn1, value=value_attn1)

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x = n + x
        if "middle_patch" in transformer_patches:
            patch = transformer_patches["middle_patch"]
            for p in patch:
                x = p(x, extra_options)

        # ---- attn2 (cross-attention, or self-attn if switch_temporal_ca_to_sa) ----
        if self.attn2 is not None:
            n = self.norm2(x)
            if self.switch_temporal_ca_to_sa:
                context_attn2 = n
            else:
                context_attn2 = context
            value_attn2 = None
            if "attn2_patch" in transformer_patches:
                patch = transformer_patches["attn2_patch"]
                value_attn2 = context_attn2
                for p in patch:
                    n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

            attn2_replace_patch = transformer_patches_replace.get("attn2", {})
            block_attn2 = transformer_block
            if block_attn2 not in attn2_replace_patch:
                block_attn2 = block

            if block_attn2 in attn2_replace_patch:
                if value_attn2 is None:
                    value_attn2 = context_attn2
                n = self.attn2.to_q(n)
                context_attn2 = self.attn2.to_k(context_attn2)
                value_attn2 = self.attn2.to_v(value_attn2)
                n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
                n = self.attn2.to_out(n)
            else:
                n = self.attn2(n, context=context_attn2, value=value_attn2)

        if "attn2_output_patch" in transformer_patches:
            patch = transformer_patches["attn2_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        # ---- feed-forward with optional residual ----
        x = n + x
        if self.is_res:
            x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x = x_skip + x

        return x
804
+
805
+
806
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()
        # one context per depth level
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
        if not use_linear:
            self.proj_in = operations.Conv2d(in_channels,
                                             inner_dim,
                                             kernel_size=1,
                                             stride=1,
                                             padding=0, dtype=dtype, device=device)
        else:
            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
             for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = operations.Conv2d(inner_dim, in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0, dtype=dtype, device=device)
        else:
            # NOTE(review): argument order (in_channels, inner_dim) mirrors the
            # upstream implementation; presumably the two are equal whenever
            # use_linear is set — confirm before changing.
            self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options={}):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        transformer_options["activations_shape"] = list(x.shape)
        x_in = x
        x = self.norm(x)
        # conv projection happens in image layout; linear projection after
        # flattening (b, c, h, w) -> (b, h*w, c)
        if not self.use_linear:
            x = self.proj_in(x)
        x = x.movedim(1, 3).flatten(1, 2).contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        # back to image layout (b, c, h, w)
        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        # residual over the whole transformer stack
        return x + x_in
870
+
871
+
872
class SpatialVideoTransformer(SpatialTransformer):
    """SpatialTransformer extended with a temporal attention stack.

    After each spatial transformer block, tokens are regrouped along the
    frame axis and passed through a matching temporal block (time_stack);
    spatial and temporal results are merged by an AlphaBlender.
    """
    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        max_time_embed_period: int = 10000,
        attn_precision=None,
        dtype=None, device=None, operations=ops
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
            attn_precision=attn_precision,
            dtype=dtype, device=device, operations=operations
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        # temporal blocks use the same head layout as the spatial ones
        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            # temporal cross-attention reuses the spatial context
            time_context_dim = context_dim

        self.time_stack = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    # timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                    attn_precision=attn_precision,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(self.depth)
            ]
        )

        # one temporal block per spatial block (they are zipped in forward)
        assert len(self.time_stack) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        # sinusoidal frame-index embedding projected through a small MLP
        time_embed_dim = self.in_channels * 4
        self.time_pos_embed = nn.Sequential(
            operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
        )

        self.time_mixer = AlphaBlender(
            alpha=merge_factor, merge_strategy=merge_strategy
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        transformer_options={}
    ) -> torch.Tensor:
        _, _, h, w = x.shape
        transformer_options["activations_shape"] = list(x.shape)
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        if self.use_spatial_context:
            assert (
                context.ndim == 3
            ), f"n dims of spatial context should be 3 but are {context.ndim}"

            if time_context is None:
                time_context = context
            # use the context of the first frame of each clip, repeated for
            # every spatial position
            time_context_first_timestep = time_context[::timesteps]
            time_context = repeat(
                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
            )
        elif time_context is not None and not self.use_spatial_context:
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        # frame-index embedding: one index per frame, tiled over the batch
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
        emb = self.time_pos_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(
            zip(self.transformer_blocks, self.time_stack)
        ):
            transformer_options["block_index"] = it_
            x = block(
                x,
                context=spatial_context,
                transformer_options=transformer_options,
            )

            x_mix = x
            x_mix = x_mix + emb

            # regroup so the temporal block attends across frames at each
            # spatial location: (b*t, s, c) -> (b*s, t, c) and back
            B, S, C = x_mix.shape
            x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
            x_mix = mix_block(x_mix, context=time_context)  # TODO: transformer_options
            x_mix = rearrange(
                x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
            )

            x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out
1034
+
1035
+
ComfyUI/comfy/ldm/modules/ema.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1, dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ # remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.', '')
20
+ self.m_name2s_name.update({name: s_name})
21
+ self.register_buffer(s_name, p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def reset_num_updates(self):
26
+ del self.num_updates
27
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28
+
29
+ def forward(self, model):
30
+ decay = self.decay
31
+
32
+ if self.num_updates >= 0:
33
+ self.num_updates += 1
34
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35
+
36
+ one_minus_decay = 1.0 - decay
37
+
38
+ with torch.no_grad():
39
+ m_param = dict(model.named_parameters())
40
+ shadow_params = dict(self.named_buffers())
41
+
42
+ for key in m_param:
43
+ if m_param[key].requires_grad:
44
+ sname = self.m_name2s_name[key]
45
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47
+ else:
48
+ assert not key in self.m_name2s_name
49
+
50
+ def copy_to(self, model):
51
+ m_param = dict(model.named_parameters())
52
+ shadow_params = dict(self.named_buffers())
53
+ for key in m_param:
54
+ if m_param[key].requires_grad:
55
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56
+ else:
57
+ assert not key in self.m_name2s_name
58
+
59
+ def store(self, parameters):
60
+ """
61
+ Save the current parameters for restoring later.
62
+ Args:
63
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64
+ temporarily stored.
65
+ """
66
+ self.collected_params = [param.clone() for param in parameters]
67
+
68
+ def restore(self, parameters):
69
+ """
70
+ Restore the parameters stored with the `store` method.
71
+ Useful to validate the model with EMA parameters without affecting the
72
+ original optimization process. Store the parameters before the
73
+ `copy_to` method. After validation (or model saving), use this to
74
+ restore the former parameters.
75
+ Args:
76
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77
+ updated with the stored parameters.
78
+ """
79
+ for c_param, param in zip(self.collected_params, parameters):
80
+ param.data.copy_(c_param.data)
ComfyUI/comfy/ldm/modules/sub_quadratic_attention.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original source:
2
+ # https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
3
+ # license:
4
+ # MIT
5
+ # credit:
6
+ # Amin Rezaei (original author)
7
+ # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
8
+ # implementation of:
9
+ # Self-attention Does Not Need O(n2) Memory":
10
+ # https://arxiv.org/abs/2112.05682v2
11
+
12
+ from functools import partial
13
+ import torch
14
+ from torch import Tensor
15
+ from torch.utils.checkpoint import checkpoint
16
+ import math
17
+ import logging
18
+
19
+ try:
20
+ from typing import Optional, NamedTuple, List, Protocol
21
+ except ImportError:
22
+ from typing import Optional, NamedTuple, List
23
+ from typing_extensions import Protocol
24
+
25
+ from typing import List
26
+
27
+ from comfy import model_management
28
+
29
+ def dynamic_slice(
30
+ x: Tensor,
31
+ starts: List[int],
32
+ sizes: List[int],
33
+ ) -> Tensor:
34
+ slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
35
+ return x[slicing]
36
+
37
+ class AttnChunk(NamedTuple):
38
+ exp_values: Tensor
39
+ exp_weights_sum: Tensor
40
+ max_score: Tensor
41
+
42
+ class SummarizeChunk(Protocol):
43
+ @staticmethod
44
+ def __call__(
45
+ query: Tensor,
46
+ key_t: Tensor,
47
+ value: Tensor,
48
+ ) -> AttnChunk: ...
49
+
50
+ class ComputeQueryChunkAttn(Protocol):
51
+ @staticmethod
52
+ def __call__(
53
+ query: Tensor,
54
+ key_t: Tensor,
55
+ value: Tensor,
56
+ ) -> Tensor: ...
57
+
58
+ def _summarize_chunk(
59
+ query: Tensor,
60
+ key_t: Tensor,
61
+ value: Tensor,
62
+ scale: float,
63
+ upcast_attention: bool,
64
+ mask,
65
+ ) -> AttnChunk:
66
+ if upcast_attention:
67
+ with torch.autocast(enabled=False, device_type = 'cuda'):
68
+ query = query.float()
69
+ key_t = key_t.float()
70
+ attn_weights = torch.baddbmm(
71
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
72
+ query,
73
+ key_t,
74
+ alpha=scale,
75
+ beta=0,
76
+ )
77
+ else:
78
+ attn_weights = torch.baddbmm(
79
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
80
+ query,
81
+ key_t,
82
+ alpha=scale,
83
+ beta=0,
84
+ )
85
+ max_score, _ = torch.max(attn_weights, -1, keepdim=True)
86
+ max_score = max_score.detach()
87
+ attn_weights -= max_score
88
+ if mask is not None:
89
+ attn_weights += mask
90
+ torch.exp(attn_weights, out=attn_weights)
91
+ exp_weights = attn_weights.to(value.dtype)
92
+ exp_values = torch.bmm(exp_weights, value)
93
+ max_score = max_score.squeeze(-1)
94
+ return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
95
+
96
def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
    mask,
) -> Tensor:
    """Attend one query chunk against all keys/values, chunking over the kv axis.

    Each kv chunk is summarized independently (see ``_summarize_chunk``) and the
    partial results are combined with the standard streaming-softmax
    renormalization against the global per-query max. ``mask`` is the full
    additive mask for this query chunk or None.
    """
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
        # Slice out the kv chunk starting at chunk_idx along the token axis.
        key_chunk = dynamic_slice(
            key_t,
            (0, 0, chunk_idx),
            (batch_x_heads, k_channels_per_head, kv_chunk_size)
        )
        value_chunk = dynamic_slice(
            value,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, v_channels_per_head)
        )
        if mask is not None:
            # Keep only the mask columns covering this kv chunk.
            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]

        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

    chunks: List[AttnChunk] = [
        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
    ]
    # Stack each AttnChunk field across chunks: [n_chunks, batch*heads, ...].
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    # Rescale every chunk from its local max to the global max so the partial
    # numerators/denominators are comparable, then reduce over chunks.
    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights
137
+
138
+ # TODO: refactor CrossAttention#get_attention_scores to share code with this
139
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> Tensor:
    """Plain (non-kv-chunked) attention for one query chunk.

    Used as the fast path when all key tokens fit in a single kv chunk. On OOM
    during softmax, falls back to a slower in-place softmax that avoids
    allocating a second scores-sized tensor.
    """
    if upcast_attention:
        # Compute scores in fp32 with autocast disabled for stability.
        with torch.autocast(enabled=False, device_type = 'cuda'):
            query = query.float()
            key_t = key_t.float()
            # beta=0 makes baddbmm ignore the uninitialized first argument.
            attn_scores = torch.baddbmm(
                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
                query,
                key_t,
                alpha=scale,
                beta=0,
            )
    else:
        attn_scores = torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )

    if mask is not None:
        attn_scores += mask
    try:
        attn_scores = attn_scores.softmax(dim=-1)
        del attn_scores
    except model_management.OOM_EXCEPTION:
        logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
        # Manual numerically-stable softmax, done in place to avoid the extra
        # allocation that softmax() needs for its output tensor.
        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
        torch.exp(attn_scores, out=attn_scores)
        summed = torch.sum(attn_scores, dim=-1, keepdim=True)
        attn_scores /= summed
        attn_probs = attn_scores

    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
    return hidden_states_slice
182
+
183
class ScannedChunk(NamedTuple):
    """(chunk index, partial attention results) pair.

    NOTE(review): not referenced anywhere in this module's visible code —
    presumably kept from the original upstream implementation; confirm before
    removing.
    """
    chunk_idx: int
    attn_chunk: AttnChunk
186
+
187
def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
    upcast_attention=False,
    mask = None,
):
    """Computes efficient dot-product attention given query, transposed key, and value.
    This is efficient version of attention presented in
    https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
    Args:
      query: queries for calculating attention with shape of
        `[batch * num_heads, tokens, channels_per_head]`.
      key_t: keys for calculating attention with shape of
        `[batch * num_heads, channels_per_head, tokens]`.
      value: values to be used in attention with shape of
        `[batch * num_heads, tokens, channels_per_head]`.
      query_chunk_size: int: query chunks size
      kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
      kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
      use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
      upcast_attention: bool: compute attention scores in fp32 regardless of input dtype.
      mask: optional additive attention mask; a 2D mask is broadcast across batch*heads.
    Returns:
      Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    if mask is not None and len(mask.shape) == 2:
        # Promote [q, k] mask to [1, q, k] so it broadcasts over batch*heads.
        mask = mask.unsqueeze(0)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        # Slice query tokens [chunk_idx, chunk_idx + query_chunk_size).
        return dynamic_slice(
            query,
            (0, chunk_idx, 0),
            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
        )

    def get_mask_chunk(chunk_idx: int) -> Tensor:
        if mask is None:
            return None
        if mask.shape[1] == 1:
            # Single query row broadcasts over every query chunk.
            return mask
        chunk = min(query_chunk_size, q_tokens)
        return mask[:,chunk_idx:chunk_idx + chunk]

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    # Fast path: all key tokens fit in a single kv chunk, so plain (sliced)
    # attention suffices; otherwise fall back to the kv-chunked streaming path.
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
        _get_attention_scores_no_kv_chunking,
        scale=scale,
        upcast_attention=upcast_attention
    ) if k_tokens <= kv_chunk_size else (
        # kv-chunked path: streaming softmax over kv chunks per query chunk
        partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )
    )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
            mask=mask,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key_t=key_t,
            value=value,
            mask=get_mask_chunk(i * query_chunk_size)
        ) for i in range(math.ceil(q_tokens / query_chunk_size))
    ], dim=1)
    return res
ComfyUI/comfy/ldm/modules/temporal_ae.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import Iterable, Union
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+
7
+ import comfy.ops
8
+ ops = comfy.ops.disable_weight_init
9
+
10
+ from .diffusionmodules.model import (
11
+ AttnBlock,
12
+ Decoder,
13
+ ResnetBlock,
14
+ )
15
+ from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
16
+ from .attention import BasicTransformerBlock
17
+
18
def partialclass(cls, *args, **kwargs):
    """Return a subclass of *cls* whose ``__init__`` has *args*/*kwargs* pre-bound.

    The class-level analogue of ``functools.partial``: instantiating the
    returned class calls ``cls.__init__`` with the bound arguments first.
    """
    bound_init = functools.partialmethod(cls.__init__, *args, **kwargs)

    class NewCls(cls):
        __init__ = bound_init

    return NewCls
23
+
24
+
25
class VideoResBlock(ResnetBlock):
    """Spatial ResnetBlock augmented with a 3D temporal ResBlock.

    The spatial output is mixed with a temporally-processed copy via an
    alpha-blend, where alpha is either a fixed buffer or a learned (sigmoid'd)
    parameter depending on ``merge_strategy``.
    """

    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
        if video_kernel_size is None:
            # Default temporal kernel: 3 along time, 1x1 spatially.
            video_kernel_size = [3, 1, 1]
        # 3D residual block over (time, height, width); no timestep embedding.
        self.time_stack = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=False,
            skip_t_emb=True,
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            # Constant mix factor, stored as a (non-trainable) buffer.
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            # Trainable pre-sigmoid logit for the mix factor.
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, bs):
        # ``bs`` is accepted for interface parity but unused here.
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()

    def forward(self, x, temb, skip_video=False, timesteps=None):
        """Run spatial resblock, then (unless skipped) blend in temporal mixing.

        x: [(batch * timesteps), channels, h, w]; frames are flattened into the
        batch dimension. ``timesteps`` defaults to the full batch size (i.e. a
        single video).
        """
        b, c, h, w = x.shape
        if timesteps is None:
            timesteps = b

        x = super().forward(x, temb)

        if not skip_video:
            # Keep the purely-spatial result for the alpha blend below.
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = self.time_stack(x, temb)

            alpha = self.get_alpha(bs=b // timesteps).to(x.device)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x
90
+
91
+
92
class AE3DConv(ops.Conv2d):
    """2D convolution followed by a temporal 3D mixing convolution.

    Applies the inherited Conv2d spatially per frame, then (unless skipped)
    a Conv3d across the time axis with 'same'-style padding.
    """

    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        # 'Same' padding for the temporal conv, per-axis if a tuple/list given.
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = ops.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )

    def forward(self, input, timesteps=None, skip_video=False):
        """input: [(batch * timesteps), c, h, w]; timesteps defaults to the batch size."""
        if timesteps is None:
            timesteps = input.shape[0]
        x = super().forward(input)
        if skip_video:
            # Pure spatial convolution, no temporal mixing.
            return x
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")
116
+
117
+
118
class AttnVideoBlock(AttnBlock):
    """Spatial attention block with an added temporal transformer mix-in.

    After spatial attention, a single-headed BasicTransformerBlock attends
    across frames; the two results are alpha-blended (fixed or learned alpha).
    """

    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # no context, single headed, as in base class
        self.time_mix_block = BasicTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
        )

        # MLP embedding of per-frame position, added to the temporal branch.
        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            ops.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            ops.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            # Constant mix factor (non-trainable buffer).
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            # Trainable pre-sigmoid logit for the mix factor.
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def forward(self, x, timesteps=None, skip_time_block=False):
        """x: [(batch * timesteps), c, h, w]; timesteps defaults to batch size."""
        if skip_time_block:
            # Behave exactly like the plain spatial AttnBlock.
            return super().forward(x)

        if timesteps is None:
            timesteps = x.shape[0]

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        # Temporal branch: add a sinusoidal frame-index embedding, then mix
        # information across frames with the temporal transformer block.
        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)  # b, n_channels
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        alpha = self.get_alpha().to(x.device)
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        # Residual connection around the whole (spatial + temporal) block.
        return x_in + x

    def get_alpha(
        self,
    ):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
188
+
189
+
190
+
191
def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
    conv_op=ops.Conv2d,
):
    """Factory returning an AttnVideoBlock class bound to these settings.

    Returns a *class* (via partialclass), not an instance, so callers can
    construct it like any other attention-block type.

    NOTE(review): ``attn_type``, ``attn_kwargs`` and ``conv_op`` are unused —
    presumably kept to match the signature of a generic make_attn factory
    elsewhere in the codebase; confirm before removing them.
    """
    return partialclass(
        AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
    )
202
+
203
+
204
class Conv2DWrapper(torch.nn.Conv2d):
    """Plain 2D convolution that accepts and ignores extra keyword arguments.

    Lets a spatial conv be used where video-aware ops (which take e.g.
    ``timesteps``/``skip_video`` kwargs) are expected.
    """

    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
        # Any video-specific keyword arguments are intentionally discarded.
        del kwargs
        return super().forward(input)
207
+
208
+
209
class VideoDecoder(Decoder):
    """VAE decoder with temporal (video) layers swapped in.

    Depending on ``time_mode``, the base Decoder's conv / attention / resnet
    op classes are replaced with their video-aware counterparts before
    construction.
    """

    # Valid values for ``time_mode``. NOTE(review): the constructor also
    # checks against "only-last-conv", which is not listed here — confirm
    # whether that mode is meant to be supported.
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = 3,
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"

        # Swap in the video-aware op classes before the base Decoder builds
        # its layers from these kwargs.
        if self.time_mode != "attn-only":
            kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
        if self.time_mode not in ["conv-only", "only-last-conv"]:
            kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
        if self.time_mode not in ["attn-only", "only-last-conv"]:
            kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)

        super().__init__(*args, **kwargs)

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        """Return the weight tensor of the decoder's final conv layer."""
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )
ComfyUI/comfy/ldm/omnigen/omnigen2.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original code: https://github.com/VectorSpaceLab/OmniGen2
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from comfy.ldm.lightricks.model import Timesteps
10
+ from comfy.ldm.flux.layers import EmbedND
11
+ from comfy.ldm.modules.attention import optimized_attention_masked
12
+ import comfy.model_management
13
+ import comfy.ldm.common_dit
14
+
15
+
16
def apply_rotary_emb(x, freqs_cis):
    """Apply a rotary position embedding to ``x``.

    The last dimension of ``x`` is viewed as consecutive (even, odd) pairs and
    each pair is linearly combined with the two rotation components stored in
    the trailing dimension of ``freqs_cis``. Zero-length sequences pass through
    unchanged.
    """
    if x.shape[1] == 0:
        return x

    pairs = x.reshape(*x.shape[:-1], -1, 1, 2)
    rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
    return rotated.reshape(*x.shape).to(dtype=x.dtype)
23
+
24
+
25
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """SwiGLU gating: SiLU-activate the gate ``x`` and scale ``y`` by it."""
    gate = F.silu(x)
    return gate * y
27
+
28
+
29
class TimestepEmbedding(nn.Module):
    """Two-layer SiLU MLP projecting timestep frequencies to the embed dim.

    ``operations`` supplies the (possibly weight-offloaded) Linear class used
    throughout ComfyUI model code.
    """

    def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
        self.act = nn.SiLU()
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """Return the MLP embedding of ``sample`` (Linear -> SiLU -> Linear)."""
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample
41
+
42
+
43
class LuminaRMSNormZero(nn.Module):
    """AdaLN-zero style RMSNorm: conditioning produces scale and gate terms.

    A SiLU+Linear of the conditioning embedding yields four chunks:
    (scale_msa, gate_msa, scale_mlp, gate_mlp). The normed input is scaled by
    (1 + scale_msa); the gates are returned for the caller to apply.
    """

    def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        # Conditioning dim is capped at 1024 (matches the timestep embedder).
        self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
        self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (modulated normed x, gate_msa, scale_mlp, gate_mlp)."""
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        # Broadcast the per-sample scale over the sequence dimension.
        x = self.norm(x) * (1 + scale_msa[:, None])
        return x, gate_msa, scale_mlp, gate_mlp
55
+
56
+
57
class LuminaLayerNormContinuous(nn.Module):
    """LayerNorm modulated by a conditioning embedding, with optional output proj.

    The conditioning embedding is SiLU'd and linearly projected to a per-sample
    scale applied as (1 + scale) after the norm; ``linear_2`` (if configured)
    projects to ``out_dim`` — used as the model's final output head.
    """

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
        self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
        self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        """Return conditioned LayerNorm of ``x`` (optionally projected to out_dim)."""
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        # Broadcast the per-sample scale over the sequence dimension.
        x = self.norm(x) * (1 + emb)[:, None, :]
        if self.linear_2 is not None:
            x = self.linear_2(x)
        return x
71
+
72
+
73
class LuminaFeedForward(nn.Module):
    """SwiGLU feed-forward block (gate, up, and down projections, no biases).

    The inner dimension is rounded up to a multiple of ``multiple_of``.
    """

    def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
        super().__init__()
        # Round inner_dim up to the nearest multiple of `multiple_of`.
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
        self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
        self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return down-projection of silu(gate) * up, i.e. the SwiGLU MLP output."""
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(swiglu(h1, h2))
84
+
85
+
86
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    """Joint embedder for diffusion timesteps and text/caption features.

    Timesteps go through sinusoidal projection + MLP (dim capped at 1024);
    caption features through RMSNorm + Linear into the model hidden size.
    """

    def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
        self.caption_embedder = nn.Sequential(
            operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
            operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (time_embed, caption_embed); the sinusoidal proj is cast to ``dtype``."""
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed
101
+
102
+
103
class Attention(nn.Module):
    """Multi-head attention with grouped KV heads, QK RMSNorm, and rotary embeddings.

    When ``kv_heads < heads`` the key/value heads are repeat-interleaved to
    match the query head count (grouped-query attention).
    """

    def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = heads
        self.kv_heads = kv_heads
        self.dim_head = dim_head
        # NOTE(review): self.scale is unused in forward — optimized_attention_masked
        # presumably applies its own scaling; confirm before removing.
        self.scale = dim_head ** -0.5

        self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
        self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
        self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)

        # Per-head RMSNorm on queries and keys (QK-norm).
        self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
            operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
            nn.Dropout(0.0)
        )

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend ``hidden_states`` (queries) over ``encoder_hidden_states`` (keys/values)."""
        batch_size, sequence_length, _ = hidden_states.shape

        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        # Split into heads: [b, seq, n_heads, dim_head].
        query = query.view(batch_size, -1, self.heads, self.dim_head)
        key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
        value = value.view(batch_size, -1, self.kv_heads, self.dim_head)

        query = self.norm_q(query)
        key = self.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # To [b, n_heads, seq, dim_head] as expected by the attention kernel.
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Grouped-query attention: duplicate KV heads to match query heads.
        if self.kv_heads < self.heads:
            key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
            value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)

        hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
        # Only the Linear of to_out is applied; the Dropout (p=0.0) is skipped.
        hidden_states = self.to_out[0](hidden_states)
        return hidden_states
152
+
153
+
154
class OmniGen2TransformerBlock(nn.Module):
    """Transformer block: self-attention + SwiGLU FF, with optional AdaLN modulation.

    With ``modulation=True`` the timestep embedding produces scale/gate terms
    (LuminaRMSNormZero); without it plain RMSNorms are used (context refiner).
    """

    def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
        super().__init__()
        self.modulation = modulation

        self.attn = Attention(
            query_dim=dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            dtype=dtype, device=device, operations=operations,
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            dtype=dtype, device=device, operations=operations
        )

        if modulation:
            self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
        else:
            self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

        # Post-sublayer norms (applied to attention / FF outputs before the residual add).
        self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
        self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run one block; ``temb`` is required when the block was built with modulation."""
        if self.modulation:
            norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
            # Self-attention: queries and keys/values are the same sequence.
            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
            # tanh-gated residual adds, per the Lumina AdaLN-zero scheme.
            hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
            hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
        else:
            norm_hidden_states = self.norm1(hidden_states)
            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
            hidden_states = hidden_states + self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
            hidden_states = hidden_states + self.ffn_norm2(mlp_output)
        return hidden_states
199
+
200
+
201
class OmniGen2RotaryPosEmbed(nn.Module):
    """Builds 3-axis rotary position embeddings for [caption | ref images | image] sequences.

    Axis 0 is a sequence/temporal index, axes 1 and 2 are patch-grid row/column.
    Caption tokens advance all three axes together; image tokens share one
    axis-0 offset and use their 2D grid coordinates on axes 1/2.
    """

    def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        # NOTE(review): axes_lens is stored but not used in this forward —
        # presumably a cap on per-axis positions; confirm against upstream.
        self.axes_lens = axes_lens
        self.patch_size = patch_size
        self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)

    def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
        """Return (cap, ref_img, img, full) freqs_cis plus cap lengths and total seq lengths.

        The per-sample lists give caption token counts, per-reference-image token
        counts, and target-image token counts; sizes are (H, W) in latent pixels.
        """
        p = self.patch_size

        # Total token count per sample: caption + all ref images + target image.
        seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # position_ids[b, t] = (axis0, row, col) for each token.
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # Caption tokens: 0..cap_len-1 replicated on all three axes.
            position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")

            # pe_shift: running axis-0 offset; pe_shift_len: running token offset.
            pe_shift = cap_seq_len
            pe_shift_len = cap_seq_len

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p

                    # Flattened row/column grid coordinates for this ref image.
                    row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
                    col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids

                    # Advance axis-0 by the larger grid extent so images don't overlap.
                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            # Target image tokens, same layout as ref images.
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p

            row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
            col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()

            position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len: seq_len, 1] = row_ids
            position_ids[i, pe_shift_len: seq_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)

        # Split the full embedding into padded per-segment tensors.
        cap_freqs_cis_shape = list(freqs_cis.shape)
        cap_freqs_cis_shape[1] = encoder_seq_len
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        ref_img_freqs_cis_shape = list(freqs_cis.shape)
        ref_img_freqs_cis_shape[1] = max_ref_img_len
        ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]

        return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
271
+
272
+
273
class OmniGen2Transformer2DModel(nn.Module):
    """OmniGen2 diffusion transformer.

    Processes three token streams -- caption text, optional reference-image
    latents, and the noised image latents -- through stream-specific refiner
    blocks, then a shared transformer stack, and finally projects back to
    latent patches. The model predicts in a flow style: `forward` flips the
    timestep (`1.0 - timesteps`) and negates its output.
    """
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        text_feat_dim: int = 1024,
        timestep_scale: float = 1.0,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()

        self.patch_size = patch_size
        # If out_channels is not given, predict the same number of channels as the input latents.
        self.out_channels = out_channels or in_channels
        self.hidden_size = hidden_size
        self.dtype = dtype

        # Rotary positional embedding shared by caption, reference-image and noise tokens.
        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )

        # Separate patch embedders for the noised latents and the reference-image latents.
        self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
        self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)

        # Joint timestep + caption embedding (produces temb and projected text tokens).
        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=text_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
        )

        # Timestep-modulated refiner for the noised image tokens.
        self.noise_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_refiner_layers)
        ])

        # Timestep-modulated refiner for reference-image tokens.
        self.ref_image_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_refiner_layers)
        ])

        # Caption tokens are refined without timestep modulation (modulation=False).
        self.context_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_refiner_layers)
        ])

        # Main transformer stack applied to the concatenated [text | ref | noise] sequence.
        self.layers = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_layers)
        ])

        # Final adaptive norm + projection back to patch_size**2 * out_channels per token.
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
        )

        # One learned embedding per reference-image slot (up to 5 reference images).
        self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))

    def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
        """Flatten image latents (and optional reference latents) into patch-token sequences.

        Returns flattened noise tokens, flattened reference tokens (or None),
        two placeholder masks (None, None), per-sample effective token counts
        for reference and noise tokens, and the spatial sizes used to build them.
        """
        batch_size = len(hidden_states)
        p = self.patch_size

        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        # Number of patch tokens per sample after patchification.
        l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]

        if ref_image_hidden_states is not None:
            # Pad each reference latent so H and W are divisible by the patch size.
            ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
            # The same list of reference sizes is replicated for every batch element.
            ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
            l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
        else:
            ref_img_sizes = [None for _ in range(batch_size)]
            l_effective_ref_img_len = [[0] for _ in range(batch_size)]

        flat_ref_img_hidden_states = None
        if ref_image_hidden_states is not None:
            # Patchify each reference image and concatenate them along the sequence dim.
            imgs = []
            for ref_img in ref_image_hidden_states:
                B, C, H, W = ref_img.size()
                ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
                imgs.append(ref_img)
            flat_ref_img_hidden_states = torch.cat(imgs, dim=1)

        # Patchify the noised latents.
        img = hidden_states
        B, C, H, W = img.size()
        flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        return (
            flat_hidden_states, flat_ref_img_hidden_states,
            None, None,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        )

    def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
        """Embed patch tokens, tag reference tokens with per-image index embeddings,
        run the stream-specific refiners, and concatenate [ref | noise] tokens.
        """
        batch_size = len(hidden_states)

        hidden_states = self.x_embedder(hidden_states)
        if ref_image_hidden_states is not None:
            ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
            image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)

            # Add the j-th image-index embedding to the token span of the j-th reference image.
            for i in range(batch_size):
                shift = 0
                for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
                    ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
                    shift += ref_img_len

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)

        if ref_image_hidden_states is not None:
            for layer in self.ref_image_refiner:
                ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)

            # Reference tokens come first; forward() recovers the noise tokens via [:, -img_len:].
            hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)

        return hidden_states

    def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
        """Denoise `x` conditioned on `context` (caption features) and optional `ref_latents`.

        x: (B, C, H, W) noised latents; context: caption hidden states;
        num_tokens: effective caption length used for RoPE; returns a tensor
        shaped like `x` (cropped back to the unpadded H, W).
        """
        B, C, H, W = x.shape
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        _, _, H_padded, W_padded = hidden_states.shape
        # Time is flipped: the model was trained with an inverted timestep convention.
        timestep = 1.0 - timesteps
        text_hidden_states = context
        text_attention_mask = attention_mask
        ref_image_hidden_states = ref_latents
        device = hidden_states.device

        temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)

        (
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

        # Per-stream rotary embeddings plus the combined one for the main stack.
        (
            context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
            rotary_emb, encoder_seq_lengths, seq_lengths,
        ) = self.rope_embedder(
            hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes, device,
        )

        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)

        img_len = hidden_states.shape[1]
        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            noise_rotary_emb, ref_img_rotary_emb,
            l_effective_ref_img_len, l_effective_img_len,
            temb,
        )

        # Full sequence: [text | ref-image | noise] tokens; main stack runs unmasked.
        hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
        attention_mask = None

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)

        hidden_states = self.norm_out(hidden_states, temb)

        # Unpatchify only the trailing noise tokens, then crop away the patch padding.
        p = self.patch_size
        output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]

        # Negated to match the sampler's sign convention for this model.
        return -output
ComfyUI/comfy/ldm/pixart/pixartms.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on:
2
+ # https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
3
+ # https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .blocks import (
8
+ t2i_modulate,
9
+ CaptionEmbedder,
10
+ AttentionKVCompress,
11
+ MultiHeadCrossAttention,
12
+ T2IFinalLayer,
13
+ SizeEmbedder,
14
+ )
15
+ from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
16
+
17
+
18
+ def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
19
+ grid_h, grid_w = torch.meshgrid(
20
+ torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
21
+ torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
22
+ indexing='ij'
23
+ )
24
+ emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
25
+ emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
26
+ emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
27
+ return emb
28
+
29
class PixArtMSBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Structure: self-attention (optionally KV-compressed) -> cross-attention
    against caption tokens -> MLP, each with a residual connection. The
    self-attention and MLP branches are modulated by six timestep-derived
    shift/scale/gate vectors; cross-attention is unmodulated.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
                 sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        # Self-attention with optional KV compression (sampling/sr_ratio control downsampling).
        self.attn = AttentionKVCompress(
            hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
            qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
        )
        # Cross-attention against the caption embedding `y`.
        self.cross_attn = MultiHeadCrossAttention(
            hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
        )
        self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        # to be compatible with lower version pytorch
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
            dtype=dtype, device=device, operations=operations
        )
        # Learned per-block base for the 6 adaLN parameters, added to the timestep projection.
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)

    def forward(self, x, y, t, mask=None, HW=None, **kwargs):
        """x: (B, N, C) image tokens; y: caption tokens; t: (B, 6*C) adaLN input;
        mask: per-sample caption lengths for cross-attention; HW: spatial size hint."""
        B, N, C = x.shape

        # Split the (table + timestep) tensor into shift/scale/gate for MSA and MLP.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
        x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
        x = x + self.cross_attn(x, y, mask)
        x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x
63
+
64
+
65
+ ### Core PixArt Model ###
66
class PixArtMS(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    PixArt (alpha/sigma) multi-scale variant: patch-embeds the latent image,
    conditions on a T5 caption embedding via cross-attention, and optionally
    micro-conditions on target size and aspect ratio. Positional embeddings
    are sin-cos, computed on the fly for the current resolution.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        pred_sigma=True,
        drop_path: float = 0.,
        caption_channels=4096,
        pe_interpolation=None,
        pe_precision=None,
        config=None,
        model_max_length=120,
        micro_condition=True,
        qk_norm=False,
        kv_compress_config=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        nn.Module.__init__(self)
        self.dtype = dtype
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        # When predicting sigma, the head outputs eps and sigma stacked on the channel dim.
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pe_interpolation = pe_interpolation
        self.pe_precision = pe_precision
        self.hidden_size = hidden_size
        self.depth = depth

        approx_gelu = lambda: nn.GELU(approximate="tanh")
        # Projects the timestep embedding to the 6 adaLN parameters consumed by each block.
        self.t_block = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
        )
        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size,
            bias=True,
            dtype=dtype,
            device=device,
            operations=operations
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations,
        )
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
            act_layer=approx_gelu, token_num=model_max_length,
            dtype=dtype, device=device, operations=operations,
        )

        # Size / aspect-ratio micro-conditioning embedders (PixArt-alpha 1024 models).
        self.micro_conditioning = micro_condition
        if self.micro_conditioning:
            self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
            self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)

        # For fixed sin-cos embedding:
        # num_patches = (input_size // patch_size) * (input_size // patch_size)
        # self.base_size = input_size // self.patch_size
        # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        if kv_compress_config is None:
            # Default: no KV compression in any layer.
            kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList([
            PixArtMSBlock(
                hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                sampling=kv_compress_config['sampling'],
                # KV compression only on the configured layer indices.
                sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
                qk_norm=qk_norm,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(
            hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
        )

    def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
        """
        Original forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) conditioning
        ar: (N, 1): aspect ratio
        cs: (N ,2) size conditioning for height/width
        """
        B, C, H, W = x.shape
        c_res = (H + W) // 2
        pe_interpolation = self.pe_interpolation
        if pe_interpolation is None or self.pe_precision is not None:
            # calculate pe_interpolation on-the-fly
            pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)

        # Positional embedding computed per-call so arbitrary resolutions work.
        pos_embed = get_2d_sincos_pos_embed_torch(
            self.hidden_size,
            h=(H // self.patch_size),
            w=(W // self.patch_size),
            pe_interpolation=pe_interpolation,
            base_size=((round(c_res / 64) * 64) // self.patch_size),
            device=x.device,
            dtype=x.dtype,
        ).unsqueeze(0)

        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep, x.dtype)  # (N, D)

        # Fold size / aspect-ratio micro-conditioning into the timestep embedding.
        if self.micro_conditioning and (c_size is not None and c_ar is not None):
            bs = x.shape[0]
            c_size = self.csize_embedder(c_size, bs)  # (N, D)
            c_ar = self.ar_embedder(c_ar, bs)  # (N, D)
            t = t + torch.cat([c_size, c_ar], dim=1)

        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, D)

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            # Pack valid caption tokens from the whole batch into one flat sequence;
            # y_lens carries the per-sample lengths for the cross-attention.
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = None
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        for block in self.blocks:
            x = block(x, y, t0, y_lens, (H, W), **kwargs)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x, H, W)  # (N, out_channels, H, W)

        return x

    def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
        """ComfyUI entry point: fills in default micro-conditioning and
        returns only the epsilon half of the prediction."""
        B, C, H, W = x.shape

        # Fallback for missing microconds
        if self.micro_conditioning:
            if c_size is None:
                # Latents are 8x downsampled, so pixel size defaults to 8 * latent size.
                c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)

            if c_ar is None:
                c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)

        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)

        ## run original forward pass
        out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)

        ## only return EPS
        if self.pred_sigma:
            return out[:, :self.in_channels]
        return out

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = h // self.patch_size
        w = w // self.patch_size
        assert h * w == x.shape[1]

        # (N, T, p*p*c) -> (N, h, w, p, p, c) -> (N, c, h*p, w*p)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs