Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- ComfyUI/comfy/comfy_types/examples/example_nodes.py +28 -0
- ComfyUI/comfy/comfy_types/examples/input_options.png +0 -0
- ComfyUI/comfy/comfy_types/examples/input_types.png +0 -0
- ComfyUI/comfy/comfy_types/examples/required_hint.png +0 -0
- ComfyUI/comfy/ldm/ace/attention.py +761 -0
- ComfyUI/comfy/ldm/ace/lyric_encoder.py +1067 -0
- ComfyUI/comfy/ldm/ace/model.py +385 -0
- ComfyUI/comfy/ldm/ace/vae/music_log_mel.py +113 -0
- ComfyUI/comfy/ldm/audio/autoencoder.py +276 -0
- ComfyUI/comfy/ldm/audio/dit.py +896 -0
- ComfyUI/comfy/ldm/audio/embedders.py +108 -0
- ComfyUI/comfy/ldm/aura/mmdit.py +498 -0
- ComfyUI/comfy/ldm/cascade/common.py +154 -0
- ComfyUI/comfy/ldm/cascade/controlnet.py +92 -0
- ComfyUI/comfy/ldm/cascade/stage_a.py +259 -0
- ComfyUI/comfy/ldm/cascade/stage_b.py +256 -0
- ComfyUI/comfy/ldm/cascade/stage_c.py +273 -0
- ComfyUI/comfy/ldm/cascade/stage_c_coder.py +98 -0
- ComfyUI/comfy/ldm/chroma/layers.py +181 -0
- ComfyUI/comfy/ldm/chroma/model.py +270 -0
- ComfyUI/comfy/ldm/cosmos/blocks.py +797 -0
- ComfyUI/comfy/ldm/cosmos/model.py +512 -0
- ComfyUI/comfy/ldm/cosmos/position_embedding.py +207 -0
- ComfyUI/comfy/ldm/cosmos/predict2.py +864 -0
- ComfyUI/comfy/ldm/cosmos/vae.py +131 -0
- ComfyUI/comfy/ldm/flux/controlnet.py +208 -0
- ComfyUI/comfy/ldm/flux/layers.py +278 -0
- ComfyUI/comfy/ldm/flux/math.py +45 -0
- ComfyUI/comfy/ldm/flux/model.py +244 -0
- ComfyUI/comfy/ldm/flux/redux.py +25 -0
- ComfyUI/comfy/ldm/hidream/model.py +802 -0
- ComfyUI/comfy/ldm/hunyuan3d/model.py +135 -0
- ComfyUI/comfy/ldm/hunyuan3d/vae.py +587 -0
- ComfyUI/comfy/ldm/hunyuan_video/model.py +355 -0
- ComfyUI/comfy/ldm/hydit/attn_layers.py +218 -0
- ComfyUI/comfy/ldm/hydit/controlnet.py +311 -0
- ComfyUI/comfy/ldm/hydit/models.py +417 -0
- ComfyUI/comfy/ldm/hydit/poolers.py +36 -0
- ComfyUI/comfy/ldm/hydit/posemb_layers.py +224 -0
- ComfyUI/comfy/ldm/lightricks/model.py +506 -0
- ComfyUI/comfy/ldm/lightricks/symmetric_patchifier.py +117 -0
- ComfyUI/comfy/ldm/lightricks/vae/causal_conv3d.py +65 -0
- ComfyUI/comfy/ldm/lumina/model.py +622 -0
- ComfyUI/comfy/ldm/models/autoencoder.py +231 -0
- ComfyUI/comfy/ldm/modules/attention.py +1035 -0
- ComfyUI/comfy/ldm/modules/ema.py +80 -0
- ComfyUI/comfy/ldm/modules/sub_quadratic_attention.py +275 -0
- ComfyUI/comfy/ldm/modules/temporal_ae.py +246 -0
- ComfyUI/comfy/ldm/omnigen/omnigen2.py +469 -0
- ComfyUI/comfy/ldm/pixart/pixartms.py +256 -0
ComfyUI/comfy/comfy_types/examples/example_nodes.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
| 2 |
+
from inspect import cleandoc
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ExampleNode(ComfyNodeABC):
    """An example node that just adds 1 to an input integer.

    * Requires a modern IDE to provide any benefit (detail: an IDE configured with analysis paths etc).
    * This node is intended as an example for developers only.
    """

    # NOTE: the class docstring above is runtime data — it is cleaned up and
    # exposed to the UI via DESCRIPTION, so do not edit it casually.
    DESCRIPTION = cleandoc(__doc__)
    CATEGORY = "examples"

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        # One required integer input. "defaultInput": True presumably makes it
        # render as a connectable socket instead of a widget — TODO confirm
        # against the ComfyUI input-options documentation.
        return {
            "required": {
                "input_int": (IO.INT, {"defaultInput": True}),
            }
        }

    # Single INT output, labelled "input_plus_one" in the UI.
    RETURN_TYPES = (IO.INT,)
    RETURN_NAMES = ("input_plus_one",)
    # Name of the method ComfyUI calls to run this node.
    FUNCTION = "execute"

    def execute(self, input_int: int):
        """Return a 1-tuple containing ``input_int + 1``."""
        return (input_int + 1,)
|
ComfyUI/comfy/comfy_types/examples/input_options.png
ADDED
|
ComfyUI/comfy/comfy_types/examples/input_types.png
ADDED
|
ComfyUI/comfy/comfy_types/examples/required_hint.png
ADDED
|
ComfyUI/comfy/ldm/ace/attention.py
ADDED
|
@@ -0,0 +1,761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
|
| 2 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from typing import Tuple, Union, Optional
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch import nn
|
| 20 |
+
|
| 21 |
+
import comfy.model_management
|
| 22 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 23 |
+
|
| 24 |
+
class Attention(nn.Module):
    """Generic attention container (adapted from diffusers' ``Attention``).

    This module only builds the projection layers; the actual attention
    computation is delegated to ``self.processor`` (see
    ``CustomLiteLAProcessor2_0`` / ``CustomerAttnProcessor2_0`` below), which is
    called from ``forward``.

    NOTE(review): ``qk_norm``, ``eps`` and ``elementwise_affine`` are accepted
    for API compatibility but never used here (``norm_q``/``norm_k`` are
    hard-wired to ``None``). ``is_causal`` and ``scale`` are stored but not
    consumed in this file.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        kv_heads: Optional[int] = None,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        processor=None,
        out_dim: int = None,
        out_context_dim: int = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
        is_causal: bool = False,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        # inner_dim / inner_kv_dim are the widths of the q and k/v projections;
        # kv_heads allows grouped-query style narrower k/v projections.
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.is_causal = is_causal

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        # When out_dim is given, the head count is derived from it.
        self.heads = out_dim // dim_head if out_dim is not None else heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        # Norm layers are intentionally disabled in this port.
        self.group_norm = None
        self.spatial_norm = None

        self.norm_q = None
        self.norm_k = None

        self.norm_cross = None
        self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
            self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
        else:
            self.to_k = None
            self.to_v = None

        # Extra projections for joint (context) streams, used by processors
        # that concatenate encoder hidden states into the attention.
        self.added_proj_bias = added_proj_bias
        if self.added_kv_proj_dim is not None:
            self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
            self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
            if self.context_pre_only is not None:
                self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
        else:
            self.add_q_proj = None
            self.add_k_proj = None
            self.add_v_proj = None

        if not self.pre_only:
            self.to_out = nn.ModuleList([])
            self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
            self.to_out.append(nn.Dropout(dropout))
        else:
            self.to_out = None

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
        else:
            self.to_add_out = None

        self.norm_added_q = None
        self.norm_added_k = None
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        """Delegate the whole attention computation to ``self.processor``."""
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class CustomLiteLAProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""

    def __init__(self):
        # ReLU kernel turns softmax attention into linear attention
        # (attention is computed as (V·K)·Q instead of softmax(Q·K)·V).
        self.kernel_func = nn.ReLU(inplace=False)
        self.eps = 1e-15
        # Value used for the extra padded row that carries the normalizer.
        self.pad_val = 1.0

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
        tensors contain rotary embeddings and are returned as real tensors.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        """
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        # Treat the last dim as interleaved (real, imag) pairs and rotate.
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        # Rotation is done in float32 for precision, then cast back.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # Remember the sample-stream length so the joint output can be split
        # back into (sample, context) at the end.
        hidden_states_len = hidden_states.shape[1]

        # Accept 4D (B, C, H, W) input by flattening the spatial dims.
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        if encoder_hidden_states is not None:
            context_input_ndim = encoder_hidden_states.ndim
            if context_input_ndim == 4:
                batch_size, channel, height, width = encoder_hidden_states.shape
                encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = hidden_states.shape[0]

        # `sample` projections.
        dtype = hidden_states.dtype
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            # attention
            if not attn.is_cross_attention:
                # Joint attention: concatenate context tokens after the sample tokens.
                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
            else:
                # NOTE(review): this branch uses the *unprojected* streams as
                # q/k/v — kept as in the original source; verify upstream.
                query = hidden_states
                key = encoder_hidden_states
                value = encoder_hidden_states

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape to per-head layout: query/value -> [B, H, D, S], key -> [B, H, S, D].
        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # RoPE expects [B, H, S, D] input.
        # query is currently [B, H, D, S]; convert to [B, H, S, D] before applying RoPE.
        query = query.permute(0, 1, 3, 2)  # [B, H, S, D] (from [B, H, D, S])

        # Apply query and key normalization if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        # query is now [B, H, S, D]; restore it back to [B, H, D, S].
        query = query.permute(0, 1, 3, 2)  # [B, H, D, S]

        if attention_mask is not None:
            # attention_mask: [B, S] -> [B, 1, S, 1]
            attention_mask = attention_mask[:, None, :, None].to(key.dtype)  # [B, 1, S, 1]
            # query [B, H, D, S] * mask permuted to [B, 1, 1, S]
            query = query * attention_mask.permute(0, 1, 3, 2)
            if not attn.is_cross_attention:
                key = key * attention_mask  # key [B, h, S, D] * mask [B, 1, S, 1]
                # value is [B, h, D, S]; permute the mask so it matches the S dimension
                value = value * attention_mask.permute(0, 1, 3, 2)

        if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype)  # [B, 1, S_enc, 1]
            # here key: [B, h, S_enc, D], value: [B, h, D, S_enc]
            key = key * encoder_attention_mask  # [B, h, S_enc, D] * [B, 1, S_enc, 1]
            value = value * encoder_attention_mask.permute(0, 1, 3, 2)  # [B, h, D, S_enc] * [B, 1, 1, S_enc]

        # Linear attention: ReLU feature map on q and k.
        query = self.kernel_func(query)
        key = self.kernel_func(key)

        # Accumulate in float32 for numerical stability.
        query, key, value = query.float(), key.float(), value.float()

        # Pad value with a constant row so the matmul also produces the
        # normalization denominator as the last channel.
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)

        vk = torch.matmul(value, key)

        hidden_states = torch.matmul(vk, query)

        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.float()

        # Divide by the accumulated normalizer (the padded last channel).
        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)

        # Merge heads back: -> [B, S, H*D]
        hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)

        hidden_states = hidden_states.to(dtype)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(dtype)

        # Split the attention outputs.
        if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : hidden_states_len],
                hidden_states[:, hidden_states_len:],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if encoder_hidden_states is not None and context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # Clamp to the fp16 representable range when running under fp16 autocast.
        if torch.get_autocast_gpu_dtype() == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
            if encoder_hidden_states is not None:
                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return hidden_states, encoder_hidden_states
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
        tensors contain rotary embeddings and are returned as real tensors.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        """
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        # Treat the last dim as interleaved (real, imag) pairs and rotate.
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:

        residual = hidden_states
        input_ndim = hidden_states.ndim

        # Accept 4D (B, C, H, W) input by flattening the spatial dims.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention when no context is given; otherwise optionally
        # normalize the context stream first.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # -> [B, heads, S, head_dim]
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
            # attention_mask: N x S1
            # encoder_attention_mask: N x S2
            # cross attention: combine attention_mask and encoder_attention_mask
            # into one additive bias (0 where both are 1, -inf elsewhere).
            # NOTE(review): assumes attention_mask is not None here — confirm callers.
            combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)

        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = optimized_attention(
            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
        ).to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
| 453 |
+
|
| 454 |
+
def val2list(x, repeat_time=1) -> list:
    """Coerce ``x`` to a list.

    A list or tuple is shallow-copied; any other value is repeated
    ``repeat_time`` times.

    Note: the previous annotation ``x: list or tuple or any`` evaluated to
    plain ``list`` at runtime (``or`` picks the first truthy operand), which
    was misleading — ``x`` may be any value.

    Args:
        x: Value to convert.
        repeat_time: Number of copies when ``x`` is a scalar (default 1).

    Returns:
        list: The converted list.
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    # Same object repeated, matching the original list-comprehension behavior.
    return [x] * repeat_time
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def val2tuple(x, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Coerce ``x`` to a tuple of at least ``min_len`` elements.

    Missing elements are filled by repeating the element at ``idx_repeat``
    (inserted just before that position), e.g. ``val2tuple(3, 3) == (3, 3, 3)``
    and ``val2tuple((1, 2), 3) == (1, 2, 2)``.

    Note: the previous annotation ``x: list or tuple or any`` evaluated to
    plain ``list`` at runtime — ``x`` may be any value.

    Args:
        x: Value to convert.
        min_len: Minimum length of the resulting tuple.
        idx_repeat: Index of the element to repeat when padding (default: last).

    Returns:
        tuple: The padded tuple.
    """
    # convert to list first
    x = val2list(x)

    # repeat elements if necessary (no-op when min_len <= len(x))
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def t2i_modulate(x, shift, scale):
    """AdaLN-style modulation: scale ``x`` by ``1 + scale``, then add ``shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    """Return the "same" padding for an odd ``kernel_size`` (element-wise for tuples).

    For stride-1 convolutions, padding ``k // 2`` keeps the output length equal
    to the input length, which only works for odd kernels.

    Args:
        kernel_size: A single kernel size or a tuple of per-dimension sizes.

    Returns:
        The padding (or tuple of paddings) ``kernel_size // 2``.

    Raises:
        ValueError: If any kernel size is even. (Previously an ``assert``,
            which is silently stripped under ``python -O``.)
    """
    if isinstance(kernel_size, tuple):
        return tuple([get_same_padding(ks) for ks in kernel_size])
    else:
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel size {kernel_size} should be odd number")
        return kernel_size // 2
|
| 483 |
+
|
| 484 |
+
class ConvLayer(nn.Module):
    """Conv1d followed by optional RMSNorm and optional SiLU activation.

    Building block for the GLU-style MBConv below.  Note that ``norm`` and
    ``act`` act purely as on/off flags here: any non-None ``norm`` selects
    RMSNorm and any non-None ``act`` selects SiLU, regardless of the value
    passed.
    """
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=None,
        act=None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        if padding is None:
            # 'same' padding for odd kernels, scaled by dilation so the
            # effective receptive field stays centered.
            padding = get_same_padding(kernel_size)
            padding *= dilation

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.padding = padding
        self.use_bias = use_bias

        self.conv = operations.Conv1d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=use_bias,
            device=device,
            dtype=dtype
        )
        # Any non-None `norm` value enables RMSNorm (the value itself is ignored).
        if norm is not None:
            self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
        else:
            self.norm = None
        # Any non-None `act` value enables SiLU (the value itself is ignored).
        if act is not None:
            self.act = nn.SiLU(inplace=True)
        else:
            self.act = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> (norm) -> (act); x is (batch, channels, time)."""
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class GLUMBConv(nn.Module):
    """Gated (GLU) mobile inverted bottleneck conv used as the Sana FFN.

    Pipeline: 1x1 expand to 2*hidden -> depthwise conv -> split into
    (value, gate), gate through SiLU, multiply -> 1x1 project back.
    Input/output are (batch, time, features); convs run in channels-first.
    """
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_feature=None,
        kernel_size=3,
        stride=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=(None, None, None),
        act=("silu", "silu", None),
        dilation=1,
        dtype=None, device=None, operations=None
    ):
        out_feature = out_feature or in_features
        super().__init__()
        # Broadcast per-stage settings to the three conv layers
        # (inverted, depthwise, pointwise).
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act = val2tuple(act, 3)

        # Non-inplace SiLU: the gate tensor is a chunk view of `x`.
        self.glu_act = nn.SiLU(inplace=False)
        self.inverted_conv = ConvLayer(
            in_features,
            hidden_features * 2,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act=act[0],
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.depth_conv = ConvLayer(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size,
            stride=stride,
            groups=hidden_features * 2,
            padding=padding,
            use_bias=use_bias[1],
            norm=norm[1],
            act=None,
            dilation=dilation,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.point_conv = ConvLayer(
            hidden_features,
            out_feature,
            1,
            use_bias=use_bias[2],
            norm=norm[2],
            act=act[2],
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, time, features) -> (batch, time, out_feature)."""
        # Convs expect channels-first: (batch, features, time).
        x = x.transpose(1, 2)
        x = self.inverted_conv(x)
        x = self.depth_conv(x)

        # GLU: first half is the value path, second half is the gate.
        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.glu_act(gate)
        x = x * gate

        x = self.point_conv(x)
        x = x.transpose(1, 2)

        return x
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class LinearTransformerBlock(nn.Module):
    """
    A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.

    Layout: RMSNorm -> linear-attention (self or joint) -> optional cross
    attention -> RMSNorm -> GLUMBConv feed-forward, with adaLN-single
    shift/scale/gate applied around the attention and FFN when
    ``use_adaln_single`` is set.
    """
    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        use_adaln_single=True,
        cross_attention_dim=None,
        added_kv_proj_dim=None,
        context_pre_only=False,
        mlp_ratio=4.0,
        add_cross_attention=False,
        add_cross_attention_dim=None,
        qk_norm=None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()

        # NOTE(review): norm1 is built without dtype/device while norm2 and the
        # other submodules pass them — presumably relies on the operations
        # wrapper's casting; confirm intended.
        self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
        # Linear (lite) attention; also handles the joint text branch when
        # encoder_hidden_states is given.
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            qk_norm=qk_norm,
            processor=CustomLiteLAProcessor2_0(),
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.add_cross_attention = add_cross_attention
        self.context_pre_only = context_pre_only

        # Separate softmax cross-attention branch (only built when both the
        # flag and its dim are provided).
        if add_cross_attention and add_cross_attention_dim is not None:
            self.cross_attn = Attention(
                query_dim=dim,
                cross_attention_dim=add_cross_attention_dim,
                added_kv_proj_dim=add_cross_attention_dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                context_pre_only=context_pre_only,
                bias=True,
                qk_norm=qk_norm,
                processor=CustomerAttnProcessor2_0(),
                dtype=dtype,
                device=device,
                operations=operations,
            )

        # eps passed positionally (1e-06), matching norm1's eps=1e-6.
        self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)

        self.ff = GLUMBConv(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            use_bias=(True, True, False),
            norm=(None, None, None),
            act=("silu", "silu", None),
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.use_adaln_single = use_adaln_single
        if use_adaln_single:
            # (6, dim): shift/scale/gate for MSA and MLP, added to `temb`.
            self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        encoder_attention_mask: torch.FloatTensor = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        temb: torch.FloatTensor = None,
    ):
        """Run one transformer block; returns updated hidden_states.

        When ``use_adaln_single`` is False, ``temb`` is unused and the
        shift/scale/gate modulations are skipped entirely.
        """

        N = hidden_states.shape[0]

        # step 1: AdaLN single — derive the six modulation tensors from the
        # learned table plus the conditioning embedding. Note: temb must be
        # reshapeable to (N, 6, dim); requires temb (and thus use_adaln_single)
        # whenever the flag is set.
        if self.use_adaln_single:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
            ).chunk(6, dim=1)

        norm_hidden_states = self.norm1(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        # step 2: attention — joint (self + text) when there is no dedicated
        # cross-attention branch, pure self-attention otherwise.
        if not self.add_cross_attention:
            attn_output, encoder_hidden_states = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
        else:
            attn_output, _ = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=None,
            )

        # Gate and residual-add the attention output.
        if self.use_adaln_single:
            attn_output = gate_msa * attn_output
        hidden_states = attn_output + hidden_states

        # Dedicated (ungated) cross-attention residual branch.
        if self.add_cross_attention:
            attn_output = self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
            hidden_states = attn_output + hidden_states

        # step 3: add norm
        norm_hidden_states = self.norm2(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        # step 4: feed forward (gated, residual)
        ff_output = self.ff(norm_hidden_states)
        if self.use_adaln_single:
            ff_output = gate_mlp * ff_output

        hidden_states = hidden_states + ff_output

        return hidden_states
|
ComfyUI/comfy/ldm/ace/lyric_encoder.py
ADDED
|
@@ -0,0 +1,1067 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/lyrics_utils/lyric_encoder.py
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
import comfy.model_management
|
| 8 |
+
|
| 9 |
+
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True,
                 dtype=None, device=None, operations=None):
        """Construct an ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (int): Whether use causal convolution or not
        """
        super().__init__()

        # 1x1 conv expanding to 2*channels, later halved again by GLU.
        self.pointwise_conv1 = operations.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            dtype=dtype, device=device
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        # padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for none causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = operations.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
            dtype=dtype, device=device
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = operations.LayerNorm(channels, dtype=dtype, device=device)

        self.pointwise_conv2 = operations.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            dtype=dtype, device=device
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) meas fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        # NOTE: masked_fill_ mutates x in place (x is already a fresh
        # transposed tensor view at this point).
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                # No cache yet: left-pad with zeros for causal context.
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                # Prepend cached left context along time.
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            # Keep the last `lorder` frames as the cache for the next chunk.
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        # LayerNorm normalizes the channel (last) dim, so swap time/channels
        # around it; BatchNorm1d works channels-first directly.
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
|
| 135 |
+
|
| 136 |
+
class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    FeedForward are appied on each position of the sequence.
    The output dim is same with the input dim.

    Args:
        idim (int): Input dimenstion.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
        dtype=None, device=None, operations=None
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        # Expansion, activation+dropout, then projection back to idim.
        self.w_1 = operations.Linear(idim, hidden_units, dtype=dtype, device=device)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = operations.Linear(hidden_units, idim, dtype=dtype, device=device)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        hidden = self.activation(self.w_1(xs))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
|
| 173 |
+
|
| 174 |
+
class Swish(torch.nn.Module):
    """Swish (SiLU) activation: ``x * sigmoid(x)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the Swish activation element-wise."""
        gate = torch.sigmoid(x)
        return gate * x
|
| 180 |
+
|
| 181 |
+
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True,
                 dtype=None, device=None, operations=None):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
        # Only the key projection's bias is configurable (key_bias).
        self.linear_k = operations.Linear(n_feat, n_feat, bias=key_bias, dtype=dtype, device=device)
        self.linear_v = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
        self.linear_out = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        # Project then split the feature dim into (heads, d_k).
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)

        if mask is not None and mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            # Second masked_fill zeroes weights so fully-masked rows (all
            # -inf -> NaN softmax) produce 0 instead of NaN.
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)

        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        # Merge heads back into the model dimension.
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1.When applying cross attention between decoder and encoder,
                the batch padding mask for input is in (#batch, 1, T) shape.
                2.When applying self attention of encoder,
                the mask is in (#batch, T, T) shape.
                3.When applying self attention of decoder,
                the mask is in (#batch, L, L) shape.
                4.If the different position in decoder see different block
                of the encoder, such as Mocha, the passed in mask could be
                in (#batch, L, T) shape. But there is no such case in current
                CosyVoice.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`


        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)
        # Prepend cached keys/values (stored concatenated on the last dim).
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # New cache always includes the freshly computed k/v (the caller is
        # responsible for trimming it).
        new_cache = torch.cat((k, v), dim=-1)

        # Scaled dot-product scores.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
| 332 |
+
"""Multi-Head Attention layer with relative position encoding.
|
| 333 |
+
Paper: https://arxiv.org/abs/1901.02860
|
| 334 |
+
Args:
|
| 335 |
+
n_head (int): The number of heads.
|
| 336 |
+
n_feat (int): The number of features.
|
| 337 |
+
dropout_rate (float): Dropout rate.
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
def __init__(self,
|
| 341 |
+
n_head: int,
|
| 342 |
+
n_feat: int,
|
| 343 |
+
dropout_rate: float,
|
| 344 |
+
key_bias: bool = True,
|
| 345 |
+
dtype=None, device=None, operations=None):
|
| 346 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
| 347 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias, dtype=dtype, device=device, operations=operations)
|
| 348 |
+
# linear transformation for positional encoding
|
| 349 |
+
self.linear_pos = operations.Linear(n_feat, n_feat, bias=False, dtype=dtype, device=device)
|
| 350 |
+
# these two learnable bias are used in matrix c and matrix d
|
| 351 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 352 |
+
self.pos_bias_u = nn.Parameter(torch.empty(self.h, self.d_k, dtype=dtype, device=device))
|
| 353 |
+
self.pos_bias_v = nn.Parameter(torch.empty(self.h, self.d_k, dtype=dtype, device=device))
|
| 354 |
+
# torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
| 355 |
+
# torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
| 356 |
+
|
| 357 |
+
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
| 358 |
+
"""Compute relative positional encoding.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
| 362 |
+
time1 means the length of query vector.
|
| 363 |
+
|
| 364 |
+
Returns:
|
| 365 |
+
torch.Tensor: Output tensor.
|
| 366 |
+
|
| 367 |
+
"""
|
| 368 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
| 369 |
+
device=x.device,
|
| 370 |
+
dtype=x.dtype)
|
| 371 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
| 372 |
+
|
| 373 |
+
x_padded = x_padded.view(x.size()[0],
|
| 374 |
+
x.size()[1],
|
| 375 |
+
x.size(3) + 1, x.size(2))
|
| 376 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
| 377 |
+
:, :, :, : x.size(-1) // 2 + 1
|
| 378 |
+
] # only keep the positions from 0 to time2
|
| 379 |
+
return x
|
| 380 |
+
|
| 381 |
+
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        # forward_qkv is defined on the parent attention class (not visible
        # here); it projects and splits inputs into per-head tensors.
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # A non-empty cache holds previous keys/values concatenated along the
        # last dim; split and prepend them along the time axis.
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        # Project the positional embeddings into per-head space.
        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k); biases are cast to match q's dtype/device.
        q_with_bias_u = (q + comfy.model_management.cast_to(self.pos_bias_u, dtype=q.dtype, device=q.device)).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + comfy.model_management.cast_to(self.pos_bias_v, dtype=q.dtype, device=q.device)).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        # (shapes differ when pos_emb covers 2*time1-1 relative positions).
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, fully attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is need, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask of shape (size, size)

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    idx = torch.arange(size, device=device)
    # Row i is True for columns 0..i: a lower-triangular causal pattern
    # built by broadcasting a column-index grid against the row index.
    return idx.unsqueeze(0).expand(size, size) <= idx.unsqueeze(-1)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
    this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask of shape (size, size)

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
    for row in range(size):
        chunk_idx = row // chunk_size
        # Attend up to the end of the current chunk...
        end = min((chunk_idx + 1) * chunk_size, size)
        # ...and back through num_left_chunks previous chunks (or all of them).
        if num_left_chunks < 0:
            begin = 0
        else:
            begin = max((chunk_idx - num_left_chunks) * chunk_size, 0)
        mask[row, begin:end] = True
    return mask
|
| 522 |
+
|
| 523 |
+
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            # Decoding with full context: no chunking at all.
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            # Decoding with a fixed chunk size.
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            # NOTE: this branch is stochastic (torch.randint) — intended for
            # training-time chunk randomization.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        # Combine with the padding mask; broadcasting yields (B, L, L).
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        # No chunking requested: the padding mask is used as-is.
        # NOTE(review): the chunked branches assume `masks` is not None;
        # confirm callers always pass a real mask when chunking is enabled.
        chunk_masks = masks
    return chunk_masks
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.

    Sub-block order: (optional macaron FFN) -> self-attention ->
    (optional conv module) -> FFN -> (final norm when conv is used),
    each wrapped in a residual connection with pre- or post-LayerNorm.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvlutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
            self,
            size: int,
            self_attn: torch.nn.Module,
            feed_forward: Optional[nn.Module] = None,
            feed_forward_macaron: Optional[nn.Module] = None,
            conv_module: Optional[nn.Module] = None,
            dropout_rate: float = 0.1,
            normalize_before: bool = True,
            dtype=None, device=None, operations=None
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)  # for the FNN module
        self.norm_mha = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)
            # Macaron style uses two half-weighted FFNs per layer.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)  # for the CNN module
            self.norm_final = operations.LayerNorm(
                size, eps=1e-5, dtype=dtype, device=device)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1,time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
        """

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        # Extra final norm only exists when the conv module is present.
        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    The table `self.pe` caches both positive and negative relative
    positions (length 2 * input_len - 1) and is lazily regrown by
    `extend_pe` whenever a longer input arrives.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct an PositionalEncoding object."""
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        # Inputs are scaled by sqrt(d_model) before the embedding is applied.
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Pre-build the table for max_len using a dummy (1, max_len) tensor.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Reset the positional encodings (regrow the cache if needed)."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` means to the position of query vector and `j` means the
        # position of key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int or torch.Tensor): start offset (streaming use).

        Returns:
            torch.Tensor: Scaled input tensor (batch, time, `*`), after dropout.
            torch.Tensor: Relative positional embedding, after dropout.

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """ For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a none
        streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        # NOTE(review): `offset` is accepted but not used in the slice below —
        # the window is always centered on the table's midpoint; confirm this
        # is intended for the streaming path.
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
        ]
        return pos_emb
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
class LinearEmbed(torch.nn.Module):
    """Linear transform the input without subsampling.

    Projects (#batch, time, idim) features to odim with a
    Linear -> LayerNorm -> Dropout stack, then applies the supplied
    positional encoding module.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc_class (torch.nn.Module): positional encoding module,
            e.g. EspnetRelPositionalEncoding.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module, dtype=None, device=None, operations=None):
        """Construct an linear object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            operations.Linear(idim, odim, dtype=dtype, device=device),
            operations.LayerNorm(odim, eps=1e-5, dtype=dtype, device=device),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class  # rel_pos_espnet

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """Delegate streaming position-encoding lookup to the wrapped module."""
        return self.pos_enc.position_encoding(offset, size)

    def forward(
        self,
        x: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project input and compute its positional embedding.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            offset (int or torch.Tensor): start offset for the positional
                encoding (streaming use).

        Returns:
            torch.Tensor: projected tensor (#batch, time', odim),
                where time' = time .
            torch.Tensor: positional embedding for the projected tensor.

        """
        # Fix: the return annotation previously declared a 3-tuple and the
        # docstring described a mask, but only (x, pos_emb) is returned.
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
# Registry mapping config attention-type strings to attention classes.
ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
}

# Registry mapping config activation names to activation modules.
ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    # Fall back to the local Swish implementation if torch lacks SiLU.
    "swish": getattr(torch.nn, "SiLU", Swish),
    "gelu": torch.nn.GELU,
}
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): Pad to this length; when <= 0, use lengths.max().
    Returns:
        torch.Tensor: Boolean mask, True at padded positions.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    n = max_len if max_len > 0 else lengths.max().item()
    positions = torch.arange(n, dtype=torch.int64, device=lengths.device)
    # A position is padding when its index reaches or exceeds the
    # sequence length; broadcasting yields the (B, n) comparison grid.
    grid = positions.unsqueeze(0).expand(lengths.size(0), n)
    return grid >= lengths.unsqueeze(-1)
|
| 916 |
+
|
| 917 |
+
#https://github.com/FunAudioLLM/CosyVoice/blob/main/examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml
|
| 918 |
+
class ConformerEncoder(torch.nn.Module):
    """Conformer encoder module.

    A stack of ConformerEncoderLayer blocks over a LinearEmbed input layer
    with ESPnet-style relative positional encoding.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 1024,
        attention_heads: int = 16,
        linear_units: int = 4096,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = 'linear',
        pos_enc_layer_type: str = 'rel_pos_espnet',
        normalize_before: bool = True,
        static_chunk_size: int = 1,  # 1: causal_mask; 0: full_mask
        use_dynamic_chunk: bool = False,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = False,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = False,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        dtype=None, device=None, operations=None
    ):
        """Construct ConformerEncoder

        Args:
            input_size to use_dynamic_chunk, see in BaseEncoder
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now, it's just for configure
                compatibility. #'rel_selfattn'
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            key_bias: whether use bias in attention.linear_k, False for whisper models.
        """
        super().__init__()
        self.output_size = output_size
        self.embed = LinearEmbed(input_size, output_size, dropout_rate,
                                 EspnetRelPositionalEncoding(output_size, positional_dropout_rate), dtype=dtype, device=device, operations=operations)
        self.normalize_before = normalize_before
        self.after_norm = operations.LayerNorm(output_size, eps=1e-5, dtype=dtype, device=device)
        self.static_chunk_size = static_chunk_size
        # Fix: use_dynamic_chunk was assigned twice in the original; once is enough.
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        activation = ACTIVATION_CLASSES[activation_type]()

        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttention(
                    *encoder_selfattn_layer_args, dtype=dtype, device=device, operations=operations),
                PositionwiseFeedForward(*positionwise_layer_args, dtype=dtype, device=device, operations=operations),
                PositionwiseFeedForward(
                    *positionwise_layer_args, dtype=dtype, device=device, operations=operations) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args, dtype=dtype, device=device, operations=operations) if use_cnn_module else None,
                dropout_rate,
                normalize_before, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_blocks)
        ])

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        """Run all encoder layers sequentially; caches are discarded."""
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward(
        self,
        xs: torch.Tensor,
        pad_mask: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            pad_mask: padding mask (B, T); may be None for no masking
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        masks = None
        if pad_mask is not None:
            masks = pad_mask.to(torch.bool).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb = self.embed(xs)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)

        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks
|
| 1067 |
+
|
ComfyUI/comfy/ldm/ace/model.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
|
| 2 |
+
|
| 3 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from typing import Optional, List, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
|
| 21 |
+
import comfy.model_management
|
| 22 |
+
|
| 23 |
+
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
| 24 |
+
from .attention import LinearTransformerBlock, t2i_modulate
|
| 25 |
+
from .lyric_encoder import ConformerEncoder as LyricEncoder
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def cross_norm(hidden_states, controlnet_input):
    """Re-normalize controlnet features to match hidden_states' statistics.

    Both inputs are (N, T, C); mean/std are taken per sample over (T, C),
    so the controlnet signal is shifted/scaled to the hidden states' range.
    """
    reduce_dims = (1, 2)
    ref_mean = hidden_states.mean(dim=reduce_dims, keepdim=True)
    ref_std = hidden_states.std(dim=reduce_dims, keepdim=True)
    src_mean = controlnet_input.mean(dim=reduce_dims, keepdim=True)
    src_std = controlnet_input.std(dim=reduce_dims, keepdim=True)
    # eps guards division when the controlnet input is (near-)constant.
    rescaled = (controlnet_input - src_mean) * (ref_std / (src_std + 1e-12))
    return rescaled + ref_mean
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
|
| 37 |
+
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding with a lazily-grown cos/sin cache.

    Returns ``(cos, sin)`` tables of shape [seq_len, dim] whose last
    dimension holds the per-channel angles duplicated once (freqs, freqs).
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies for the even channel indices: base^(-2i/dim).
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        # Build the cache eagerly so `torch.jit.trace` sees the buffers.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cached cos/sin tables for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        # Different from the paper, but uses a different permutation to obtain the same calculation.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]; grow cache on demand.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        cos = self.cos_cached[:seq_len].to(dtype=x.dtype)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype)
        return cos, sin
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class T2IFinalLayer(nn.Module):
    """
    The final layer of Sana.

    Applies an AdaLN-style shift/scale modulation driven by the timestep
    embedding, projects tokens to patch pixels, and un-patchifies back to an
    (N, C, H, W) latent layout of the requested output length.
    """

    def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
        super().__init__()
        # elementwise_affine=False: the affine part comes from scale_shift_table + timestep.
        self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
        # Learned per-channel (shift, scale) baseline; added to the timestep embedding in forward().
        self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
        self.out_channels = out_channels
        self.patch_size = patch_size

    def unpatchfy(
        self,
        hidden_states: torch.Tensor,
        width: int,
    ):
        # 4 unpatchify
        # The token sequence is treated as a 1 x T patch grid (single patch row).
        new_height, new_width = 1, hidden_states.size(1)
        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
        ).contiguous()
        # (N, h, w, p, q, C) -> (N, C, h, p, w, q): reassemble patches into the image plane.
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
        ).contiguous()
        # Pad or crop the width axis so the result matches the requested length.
        # NOTE(review): `width` is compared against the token count new_width, which equals
        # the pixel width only when patch_size[1] == 1 (the default) — confirm for other patch sizes.
        if width > new_width:
            output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
        elif width < new_width:
            output = output[:, :, :, :width]
        return output

    def forward(self, x, t, output_length):
        # Combine the learned table with the per-sample timestep embedding, then split
        # into shift and scale along the table's first axis.
        shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        # unpatchify
        output = self.unpatchfy(x, output_length)
        return output
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding.

    Uses a small conv stack instead of a single linear projection: a strided
    conv performs the patchify, GroupNorm stabilizes it, and a 1x1 conv
    projects to the transformer width.
    """

    def __init__(
        self,
        height=16,
        width=4096,
        patch_size=(16, 1),
        in_channels=8,
        embed_dim=1152,
        bias=True,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        patch_size_h, patch_size_w = patch_size
        self.early_conv_layers = nn.Sequential(
            # Patchify: kernel == stride == patch_size, so each patch maps to one output pixel.
            operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
            operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
            # 1x1 projection to the transformer's embedding width.
            operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
        )
        self.patch_size = patch_size
        # Grid size (in patches) at the configured maximum input resolution.
        self.height, self.width = height // patch_size_h, width // patch_size_w
        self.base_size = self.width

    def forward(self, latent):
        # early convolutions, N x C x H x W -> N x 256 * sqrt(patch_size) x H/patch_size x W/patch_size
        latent = self.early_conv_layers(latent)
        latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return latent
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class ACEStepTransformer2DModel(nn.Module):
    """ACE-Step music-generation transformer.

    Conditions a stack of LinearTransformerBlocks (self + cross attention,
    rotary positions) on a concatenation of speaker, genre/text, and lyric
    embeddings, and decodes patchified audio latents back to (N, C, H, W).
    """
    # _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: Optional[int] = 8,
        num_layers: int = 28,
        inner_dim: int = 1536,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        mlp_ratio: float = 4.0,
        out_channels: int = 8,
        max_position: int = 32768,
        rope_theta: float = 1000000.0,
        speaker_embedding_dim: int = 512,
        text_embedding_dim: int = 768,
        ssl_encoder_depths: List[int] = [9, 9],
        ssl_names: List[str] = ["mert", "m-hubert"],
        ssl_latent_dims: List[int] = [1024, 768],
        lyric_encoder_vocab_size: int = 6681,
        lyric_hidden_size: int = 1024,
        patch_size: List[int] = [16, 1],
        max_height: int = 16,
        max_width: int = 4096,
        audio_model=None,
        dtype=None, device=None, operations=None

    ):
        super().__init__()

        self.dtype = dtype
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # NOTE: the `inner_dim` argument is overridden here; the effective width
        # is always heads * head_dim.
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.out_channels = out_channels
        self.max_position = max_position
        self.patch_size = patch_size

        self.rope_theta = rope_theta

        # Shared rotary-embedding cache used for both self- and cross-attention positions.
        self.rotary_emb = Qwen2RotaryEmbedding(
            dim=self.attention_head_dim,
            max_position_embeddings=self.max_position,
            base=self.rope_theta,
            dtype=dtype,
            device=device,
        )

        # 2. Define input layers
        self.in_channels = in_channels

        self.num_layers = num_layers
        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                LinearTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    add_cross_attention=True,
                    add_cross_attention_dim=self.inner_dim,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for i in range(self.num_layers)
            ]
        )

        # Timestep conditioning: sinusoidal features -> MLP embedding -> 6-way AdaLN params.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
        self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))

        # speaker
        self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)

        # genre
        self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)

        # lyric
        self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
        self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
        self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)

        projector_dim = 2 * self.inner_dim

        # Projection heads to the SSL-teacher latent spaces (one per entry in ssl_latent_dims).
        # Not used in the forward pass below; presumably for training-time distillation losses.
        self.projectors = nn.ModuleList([
            nn.Sequential(
                operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
            ) for ssl_dim in ssl_latent_dims
        ])

        self.proj_in = PatchEmbed(
            height=max_height,
            width=max_width,
            patch_size=patch_size,
            embed_dim=self.inner_dim,
            bias=True,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)

    def forward_lyric_encoder(
        self,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        out_dtype=None,
    ):
        """Embed lyric tokens, run the conformer encoder, and project to inner_dim.

        Returns an (N, T, inner_dim) tensor.
        """
        # N x T x D
        lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
        prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
        return prompt_prenet_out

    def encode(
        self,
        encoder_text_hidden_states: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        lyrics_strength=1.0,
    ):
        """Build the cross-attention context: [speaker | genre/text | lyrics].

        Returns (encoder_hidden_states, encoder_hidden_mask); the mask is None
        unless text_attention_mask is provided.
        """

        bs = encoder_text_hidden_states.shape[0]
        device = encoder_text_hidden_states.device

        # speaker embedding: one token per sample.
        encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)

        # genre embedding
        encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)

        # lyric
        encoder_lyric_hidden_states = self.forward_lyric_encoder(
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
            out_dtype=encoder_text_hidden_states.dtype,
        )

        # Scale lyric influence (in-place) before concatenation.
        encoder_lyric_hidden_states *= lyrics_strength

        encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)

        encoder_hidden_mask = None
        if text_attention_mask is not None:
            # Speaker token is always attended to; lyric_mask must be non-None here.
            speaker_mask = torch.ones(bs, 1, device=device)
            encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)

        return encoder_hidden_states, encoder_hidden_mask

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_mask: torch.Tensor,
        timestep: Optional[torch.Tensor],
        output_length: int = 0,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
    ):
        """Run the transformer stack over patchified latents and un-patchify."""
        embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
        temb = self.t_block(embedded_timestep)

        hidden_states = self.proj_in(hidden_states)

        # controlnet logic: re-normalize the control signal to the hidden states'
        # statistics before mixing it in.
        if block_controlnet_hidden_states is not None:
            control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
            hidden_states = hidden_states + control_condi * controlnet_scale

        # inner_hidden_states = []

        rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])

        for index_block, block in enumerate(self.transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_hidden_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                temb=temb,
            )

        # Final AdaLN + projection + unpatchify back to the requested latent width.
        output = self.final_layer(hidden_states, embedded_timestep, output_length)
        return output

    def forward(
        self,
        x,
        timestep,
        attention_mask=None,
        context: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.LongTensor] = None,
        speaker_embeds: Optional[torch.FloatTensor] = None,
        lyric_token_idx: Optional[torch.LongTensor] = None,
        lyric_mask: Optional[torch.LongTensor] = None,
        block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
        controlnet_scale: Union[float, torch.Tensor] = 1.0,
        lyrics_strength=1.0,
        **kwargs
    ):
        """ComfyUI entry point: encode the conditioning, then decode the latents.

        `x` is the noisy latent (N, C, H, W); `context` is the genre/text embedding.
        """
        hidden_states = x
        encoder_text_hidden_states = context
        encoder_hidden_states, encoder_hidden_mask = self.encode(
            encoder_text_hidden_states=encoder_text_hidden_states,
            text_attention_mask=text_attention_mask,
            speaker_embeds=speaker_embeds,
            lyric_token_idx=lyric_token_idx,
            lyric_mask=lyric_mask,
            lyrics_strength=lyrics_strength,
        )

        # Decode back to the input's temporal width.
        output_length = hidden_states.shape[-1]

        output = self.decode(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_mask=encoder_hidden_mask,
            timestep=timestep,
            output_length=output_length,
            block_controlnet_hidden_states=block_controlnet_hidden_states,
            controlnet_scale=controlnet_scale,
        )

        return output
|
ComfyUI/comfy/ldm/ace/vae/music_log_mel.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
import logging
|
| 6 |
+
try:
    from torchaudio.transforms import MelScale
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate; torchaudio is optional but ACE audio models need it.
    logging.warning("torchaudio missing, ACE model will be broken")
|
| 10 |
+
|
| 11 |
+
import comfy.model_management
|
| 12 |
+
|
| 13 |
+
class LinearSpectrogram(nn.Module):
    """Magnitude STFT front-end for the log-mel pipeline.

    Pads the waveform manually (reflect) so frame alignment matches the
    non-centered STFT call below, then returns the linear magnitude
    spectrogram of shape (B, n_fft // 2 + 1, frames).
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        # Hann analysis window registered as a buffer so it follows the module's device.
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, y: Tensor) -> Tensor:
        # Accept (B, 1, T) by dropping the channel dim; (B, T) passes through unchanged.
        if y.ndim == 3:
            y = y.squeeze(1)

        # Manual reflect padding of (win - hop) samples split across both ends,
        # replacing center=True padding for the non-centered STFT below.
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                (self.win_length - self.hop_length) // 2,
                (self.win_length - self.hop_length + 1) // 2,
            ),
            mode="reflect",
        ).squeeze(1)
        dtype = y.dtype
        # Run the STFT in float32 for numerical accuracy; cast back afterwards.
        spec = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            # Magnitude = sqrt(re^2 + im^2); eps keeps the sqrt gradient finite at zero.
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
        spec = spec.to(dtype)
        return spec
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram: LinearSpectrogram -> MelScale -> log clamp."""

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        # Default the upper band edge to Nyquist.
        self.f_max = f_max or sample_rate // 2

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        # Slaney-style mel filterbank with Slaney area normalization (librosa-compatible).
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            "slaney",
            "slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        # Clamp avoids log(0); floor of 1e-5 matches common log-mel pipelines.
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        # Inverse of compress() (up to the clamp floor).
        return torch.exp(x)

    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
        linear = self.spectrogram(x)
        x = self.mel_scale(linear)
        x = self.compress(x)
        # print(x.shape)
        if return_linear:
            # Also return the log-compressed linear spectrogram for auxiliary losses.
            return x, self.compress(linear)

        return x
|
ComfyUI/comfy/ldm/audio/autoencoder.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from typing import Literal
|
| 6 |
+
import math
|
| 7 |
+
import comfy.ops
|
| 8 |
+
ops = comfy.ops.disable_weight_init
|
| 9 |
+
|
| 10 |
+
def vae_sample(mean, scale):
    """Reparameterized sample from N(mean, std) with the KL term against N(0, I).

    std = softplus(scale) + 1e-4 keeps the deviation strictly positive and
    guards the log. KL is summed over channels (dim 1) and averaged over batch.
    """
    std = nn.functional.softplus(scale) + 1e-4
    variance = std * std
    log_variance = torch.log(variance)
    noise = torch.randn_like(mean)
    sample = noise * std + mean

    kl = (mean * mean + variance - log_variance - 1).sum(1).mean()

    return sample, kl
|
| 19 |
+
|
| 20 |
+
class VAEBottleneck(nn.Module):
    """Continuous VAE bottleneck: splits channels into (mean, scale) and samples."""

    def __init__(self):
        super().__init__()
        # Continuous latent space (as opposed to a quantized/discrete bottleneck).
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        # First half of the channel dim parameterizes the mean, second half the scale.
        mean, scale = x.chunk(2, dim=1)
        latents, kl = vae_sample(mean, scale)
        info = {"kl": kl}
        if return_info:
            return latents, info
        return latents

    def decode(self, x):
        # Sampling is the only transform here; decoding is the identity.
        return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def snake_beta(x, alpha, beta):
    """SnakeBeta activation: x + sin^2(alpha * x) / beta, with beta guarded against zero."""
    sin_term = torch.sin(x * alpha)
    return x + (1.0 / (beta + 0.000000001)) * (sin_term * sin_term)
|
| 45 |
+
|
| 46 |
+
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
| 47 |
+
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
class SnakeBeta(nn.Module):
    """SnakeBeta activation with learnable per-channel alpha (frequency) and beta (magnitude)."""

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha/beta: log-scale parameters start at 0 (== 1 after exp);
        # linear-scale parameters start at `alpha`.
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            init = torch.zeros(in_features)
        else:
            init = torch.ones(in_features)
        self.alpha = nn.Parameter(init * alpha)
        self.beta = nn.Parameter(init * alpha)

        # self.alpha.requires_grad = alpha_trainable
        # self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        # Broadcast (C,) parameters to [1, C, 1] so they line up with (B, C, T) input.
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device)
        beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        # Inlined snake_beta helper: x + (1/beta) * sin^2(alpha * x).
        return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
|
| 76 |
+
|
| 77 |
+
def WNConv1d(*args, **kwargs):
    # Weight-normalized Conv1d built on comfy's weight-init-disabled ops.
    return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
|
| 79 |
+
|
| 80 |
+
def WNConvTranspose1d(*args, **kwargs):
    # Weight-normalized ConvTranspose1d built on comfy's weight-init-disabled ops.
    return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
|
| 82 |
+
|
| 83 |
+
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build the requested activation module; `channels` is only consumed by "snake"."""
    factories = {
        "elu": lambda: torch.nn.ELU(),
        "snake": lambda: SnakeBeta(channels),
        "none": lambda: torch.nn.Identity(),
    }
    if activation not in factories:
        raise ValueError(f"Unknown activation {activation}")
    act = factories[activation]()

    if antialias:
        # NOTE(review): Activation1d is not defined in this file (upstream noqa'd it),
        # so antialias=True raises NameError at call time — confirm before enabling.
        act = Activation1d(act) # noqa: F821 Activation1d is not defined

    return act
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ResidualUnit(nn.Module):
    """Dilated residual block: act -> 7-tap dilated conv -> act -> 1x1 conv, plus identity skip."""

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation
        act_name = "snake" if use_snake else "elu"
        # "same" padding for a kernel of 7 at this dilation.
        pad = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            get_activation(act_name, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=pad),
            get_activation(act_name, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        # Identity skip connection around the conv stack.
        return self.layers(x) + x
|
| 123 |
+
|
| 124 |
+
class EncoderBlock(nn.Module):
    """Downsampling encoder stage: three dilated ResidualUnits, activation, strided conv."""
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        self.layers = nn.Sequential(
            # Increasing dilations (1, 3, 9) widen the receptive field before downsampling.
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=9, use_snake=use_snake),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            # Strided conv performs the temporal downsampling by `stride`.
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
        )

    def forward(self, x):
        return self.layers(x)
|
| 142 |
+
|
| 143 |
+
class DecoderBlock(nn.Module):
    """Upsampling decoder stage: activation, upsample (transposed conv or nearest+conv),
    then three dilated ResidualUnits (mirror of EncoderBlock)."""
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            # Nearest-neighbor upsample followed by a stride-1 conv avoids
            # transposed-conv checkerboard artifacts.
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2*stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            # Dilations (1, 3, 9) refine the upsampled signal.
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)
|
| 175 |
+
|
| 176 |
+
class OobleckEncoder(nn.Module):
    """Strided 1D conv encoder: stem conv -> EncoderBlocks -> activation -> latent projection."""

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        # Prepend the stem's multiplier so consecutive pairs give block in/out widths.
        widths = [1] + c_mults
        self.depth = len(widths)

        stem = WNConv1d(in_channels=in_channels, out_channels=widths[0] * channels, kernel_size=7, padding=3)

        blocks = [
            EncoderBlock(in_channels=w_in * channels, out_channels=w_out * channels,
                         stride=stride, use_snake=use_snake)
            for w_in, w_out, stride in zip(widths[:-1], widths[1:], strides)
        ]

        head = [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=widths[-1] * channels),
            WNConv1d(in_channels=widths[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1),
        ]

        self.layers = nn.Sequential(stem, *blocks, *head)

    def forward(self, x):
        return self.layers(x)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class OobleckDecoder(nn.Module):
    """Mirror of OobleckEncoder: latent conv -> upsampling DecoderBlocks -> activation -> output conv."""

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        widths = [1] + c_mults
        self.depth = len(widths)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=widths[-1] * channels, kernel_size=7, padding=3),
        ]

        # Walk the channel multipliers from widest to narrowest, undoing each encoder stride.
        for i in reversed(range(1, self.depth)):
            layers.append(DecoderBlock(
                in_channels=widths[i] * channels,
                out_channels=widths[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
            ))

        layers.append(get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=widths[0] * channels))
        layers.append(WNConv1d(in_channels=widths[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False))
        # Optional tanh squashes the output into [-1, 1] audio range.
        layers.append(nn.Tanh() if final_tanh else nn.Identity())

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class AudioOobleckVAE(nn.Module):
    """Full audio VAE: Oobleck encoder -> VAE bottleneck -> Oobleck decoder."""

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults = [1, 2, 4, 8, 16],
                 strides = [2, 4, 4, 8, 8],
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        # Encoder outputs 2 * latent_dim channels (mean and scale consumed by the bottleneck).
        self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2,
                                      c_mults, strides, use_snake, antialias_activation)
        # Decoder reconstructs the same number of channels the encoder consumed.
        self.decoder = OobleckDecoder(in_channels, channels, latent_dim,
                                      c_mults, strides, use_snake, antialias_activation,
                                      use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        """Encode audio (B, in_channels, T) to a latent via the VAE bottleneck."""
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        """Decode a latent back to audio (B, in_channels, T')."""
        return self.decoder(self.bottleneck.decode(x))
|
| 276 |
+
|
ComfyUI/comfy/ldm/audio/dit.py
ADDED
|
@@ -0,0 +1,896 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
| 2 |
+
|
| 3 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 4 |
+
import typing as tp
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
import math
|
| 12 |
+
import comfy.ops
|
| 13 |
+
|
| 14 |
+
class FourierFeatures(nn.Module):
    """Random Fourier feature projection (used for timestep embedding).

    Projects the input through a fixed random matrix and returns the
    concatenated cosines and sines of the resulting angles.
    """

    def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
        super().__init__()
        # Half the output width for the projection; cos/sin doubles it back.
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.empty(
            [out_features // 2, in_features], dtype=dtype, device=device))

    def forward(self, input):
        # Cast the weight to the input's dtype/device before projecting.
        weight_t = comfy.ops.cast_to_input(self.weight.T, input)
        f = 2 * math.pi * input @ weight_t
        return torch.cat([f.cos(), f.sin()], dim=-1)
|
| 24 |
+
|
| 25 |
+
# norms
|
| 26 |
+
class LayerNorm(nn.Module):
    def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
        """
        Layer norm without bias by default: bias-less layernorm has been shown
        to be more stable (most newer models use bias-less rmsnorm/layernorm).
        """
        super().__init__()
        # NOTE: fix_scale is accepted for config compatibility but is unused here.
        self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) if bias else None

    def forward(self, x):
        # Cast the affine parameters to the input's dtype/device at call time.
        bias = None if self.beta is None else comfy.ops.cast_to_input(self.beta, x)
        weight = comfy.ops.cast_to_input(self.gamma, x)
        return F.layer_norm(x, x.shape[-1:], weight=weight, bias=bias)
|
| 45 |
+
|
| 46 |
+
class GLU(nn.Module):
    """Gated linear unit: project to twice the output width, then gate one
    half with `activation` applied to the other half.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        activation,
        use_conv = False,
        conv_kernel_size = 3,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.act = activation
        self.use_conv = use_conv
        if use_conv:
            self.proj = operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size,
                                          padding=(conv_kernel_size // 2), dtype=dtype, device=device)
        else:
            self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        if self.use_conv:
            # Conv1d expects (batch, channels, seq): swap dims around the projection.
            projected = self.proj(x.transpose(1, 2)).transpose(1, 2)
        else:
            projected = self.proj(x)

        value, gate = projected.chunk(2, dim=-1)
        return value * self.act(gate)
|
| 73 |
+
|
| 74 |
+
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, scaled by 1/sqrt(dim)."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len = x.shape[1]
        device = x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            # Offset positions for left-padded sequences; clamping keeps padding at index 0.
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        return self.emb(pos) * self.scale
|
| 94 |
+
|
| 95 |
+
class ScaledSinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal positional embedding with a single learned scale."""

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        # One learned scalar, initialized to 1/sqrt(dim).
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        exponents = torch.arange(half_dim).float() / half_dim
        # Geometric frequency schedule theta^(-k/half_dim).
        self.register_buffer('inv_freq', theta ** -exponents, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len = x.shape[1]
        device = x.device

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            # Shift positions for left-padded sequences.
            pos = pos - seq_start_pos[..., None]

        # Outer product of positions and inverse frequencies -> phase matrix.
        phase = torch.einsum('i, j -> i j', pos, self.inv_freq)
        sinusoids = torch.cat((phase.sin(), phase.cos()), dim = -1)
        return sinusoids * self.scale
|
| 118 |
+
|
| 119 |
+
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) with optional xpos scaling.

    ``forward`` returns ``(freqs, scale)`` where ``freqs`` has shape
    ``(seq_len, dim)`` and ``scale`` is ``1.`` unless ``use_xpos`` is set.
    Note: ``inv_freq`` is registered as an *uninitialized* buffer — it is
    expected to be filled from a checkpoint (the commented-out line shows the
    intended initialization).
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.,
        dtype=None,
        device=None,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

        # Interpolation (position compression) must not extrapolate.
        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        # xpos length-extrapolation scale per rotary dimension.
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len, device, dtype):
        """Build position indices 0..seq_len-1 and delegate to forward()."""
        t = torch.arange(seq_len, device=device, dtype=dtype)
        return self.forward(t)

    def forward(self, t):
        """Compute rotary frequencies (and xpos scale, if enabled) for positions ``t``."""
        device = t.device
        # Fix: seq_len was previously undefined here (flagged noqa: F821), so the
        # xpos branch below raised NameError whenever use_xpos=True. Derive it
        # from the position tensor instead.
        seq_len = t.shape[-1]

        # Positional interpolation: compress positions to extend context.
        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # xpos: exponentially decay/amplify around the sequence midpoint.
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
|
| 177 |
+
|
| 178 |
+
def rotate_half(x):
    """Split the last dim into two halves (x1, x2) and return (-x2, x1)."""
    halves = rearrange(x, '... (j d) -> ... j d', j = 2)
    first, second = halves.unbind(dim = -2)
    return torch.cat((-second, first), dim = -1)
|
| 182 |
+
|
| 183 |
+
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Apply rotary embedding ``freqs`` to the first ``freqs.shape[-1]`` channels
    of ``t``; remaining channels pass through unrotated (partial RoPE, GPT-J style).
    """
    out_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rot_dim = freqs.shape[-1]
    seq_len = t.shape[-2]
    freqs = freqs.to(dtype)
    t = t.to(dtype)
    # Align frequencies to the trailing positions of the sequence.
    freqs = freqs[-seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        # Broadcast batched freqs across the head dimension.
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    rotated = t[..., :rot_dim]
    passthrough = t[..., rot_dim:]
    rotated = (rotated * freqs.cos() * scale) + (rotate_half(rotated) * freqs.sin() * scale)

    return torch.cat((rotated.to(out_dtype), passthrough.to(out_dtype)), dim = -1)
|
| 202 |
+
|
| 203 |
+
class _Transpose(nn.Module):
    """Parameter-free swap of the sequence and channel dims, (B, N, D) <-> (B, D, N),
    usable inside nn.Sequential."""
    def forward(self, x):
        return x.transpose(1, 2)


class FeedForward(nn.Module):
    """Transformer feed-forward block; defaults to SwiGLU (GLU with SiLU gate).

    With ``use_conv=True`` the projections are Conv1d layers, so the tensor is
    transposed to (B, D, N) around each conv and back to (B, N, D) afterwards.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        inner_dim = int(dim * mult)

        # Default to SwiGLU
        activation = nn.SiLU()

        dim_out = dim if dim_out is None else dim_out

        if glu:
            linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
        else:
            # Bugfix: the original placed bare `rearrange('b n d -> b d n')` calls in
            # these Sequentials. einops.rearrange requires a tensor as its first
            # argument, so any use_conv=True construction raised TypeError. A small
            # transpose module (parameter-free, same Sequential slot) does what was
            # intended; identity slots are unchanged so state_dict keys are preserved.
            linear_in = nn.Sequential(
                _Transpose() if use_conv else nn.Identity(),
                operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
                _Transpose() if use_conv else nn.Identity(),
                activation
            )

        linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)

        # # init last linear layer to 0
        # if zero_init_output:
        #     nn.init.zeros_(linear_out.weight)
        #     if not no_bias:
        #         nn.init.zeros_(linear_out.bias)

        self.ff = nn.Sequential(
            linear_in,
            _Transpose() if use_conv else nn.Identity(),
            linear_out,
            _Transpose() if use_conv else nn.Identity(),
        )

    def forward(self, x):
        """Apply the feed-forward stack to x of shape (B, N, dim)."""
        return self.ff(x)
|
| 255 |
+
|
| 256 |
+
class Attention(nn.Module):
    """Multi-head attention supporting self-attention (fused qkv projection) and
    cross-attention (separate q and kv projections when ``dim_context`` is given),
    optional cosine-similarity QK normalization, and rotary position embeddings.

    NOTE(review): ``zero_init_output`` and ``natten_kernel_size`` are accepted but
    unused in this implementation (the zero-init is commented out below).
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        dim_context = None,
        causal = False,
        zero_init_output=True,
        qk_norm = False,
        natten_kernel_size = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.causal = causal

        # Keys/values may come from a context stream of a different width.
        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            # Cross-attention: q projected from x, k/v projected from context.
            self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
            self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
        else:
            # Self-attention: single fused qkv projection.
            self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)

        self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # if zero_init_output:
        #     nn.init.zeros_(self.to_out.weight)

        self.qk_norm = qk_norm


    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None,
        causal = None
    ):
        """Attend over ``x`` (or over ``context`` for cross-attention).

        x: (batch, seq, dim). context: optional (batch, ctx_seq, dim_context).
        mask / context_mask: optional boolean (batch, seq) masks.
        rotary_pos_emb: optional (freqs, scale) pair from RotaryEmbedding.

        NOTE(review): the ``masks`` list and the adjusted ``causal`` flag computed
        below are never passed to ``optimized_attention``, so attention itself runs
        unmasked and non-causal here; only the final output zeroing uses ``mask``.
        """
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Use separate linear projections for q and k/v
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)

            k, v = self.to_kv(kv_input).chunk(2, dim=-1)

            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
        else:
            # Use fused linear projection
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Normalize q and k for cosine sim attention
        if self.qk_norm:
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)

        if rotary_pos_emb is not None and not has_context:
            freqs, _ = rotary_pos_emb

            q_dtype = q.dtype
            k_dtype = k.dtype

            # Rotary math is done in float32 for numerical stability, then cast back.
            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)

            q = apply_rotary_pos_emb(q, freqs)
            k = apply_rotary_pos_emb(k, freqs)

            q = q.to(q_dtype)
            k = k.to(k_dtype)

        # For cross-attention the key mask is context_mask; for self-attention
        # fall back to the sequence mask.
        input_mask = context_mask

        if input_mask is None and not has_context:
            input_mask = mask

        # determine masking
        masks = []

        if input_mask is not None:
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        # Other masks will be added here later
        n = q.shape[-2]

        causal = self.causal if causal is None else causal

        if n == 1 and causal:
            # A single query attending to itself needs no causal mask.
            causal = False

        if h != kv_h:
            # Repeat interleave kv_heads to match q_heads
            heads_per_kv_head = h // kv_h
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

        out = optimized_attention(q, k, v, h, skip_reshape=True)
        out = self.to_out(out)

        if mask is not None:
            # Zero out outputs at masked query positions.
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        return out
|
| 374 |
+
|
| 375 |
+
class ConformerModule(nn.Module):
    """Conformer convolution module: norm -> pointwise conv -> GLU ->
    depthwise conv -> norm -> swish -> pointwise conv.

    NOTE(review): GLU below is constructed without an ``operations`` argument;
    with the GLU definition in this file that means ``operations=None`` and
    construction would fail — confirm the conformer path is actually exercised.
    """

    def __init__(
        self,
        dim,
        norm_kwargs = {},
    ):
        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.glu = GLU(dim, dim, nn.SiLU())
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        # Convs run in (B, D, N); norms and the GLU run in (B, N, D), so each
        # conv is wrapped in a transpose/untranspose pair.
        y = self.in_norm(x)
        y = self.pointwise_conv(y.transpose(1, 2)).transpose(1, 2)
        y = self.glu(y)
        y = self.depthwise_conv(y.transpose(1, 2)).transpose(1, 2)
        y = self.swish(self.mid_norm(y))
        y = self.pointwise_conv_2(y.transpose(1, 2)).transpose(1, 2)
        return y
|
| 410 |
+
|
| 411 |
+
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention, optional cross-attention, optional
    conformer module, and a feed-forward block, each with a residual connection.

    When ``global_cond_dim`` is set, self-attention and feed-forward are modulated
    adaLN-style (scale/shift before, gate after) from a global conditioning vector.
    """

    def __init__(
            self,
            dim,
            dim_heads = 64,
            cross_attend = False,
            dim_context = None,
            global_cond_dim = None,
            causal = False,
            zero_init_branch_outputs = True,
            conformer = False,
            layer_ix = -1,
            remove_norms = False,
            attn_kwargs = {},
            ff_kwargs = {},
            norm_kwargs = {},
            dtype=None,
            device=None,
            operations=None,
    ):

        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        # Pre-norm before self-attention (identity if norms are disabled).
        self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()

        self.self_attn = Attention(
            dim,
            dim_heads = dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            dtype=dtype,
            device=device,
            operations=operations,
            **attn_kwargs
        )

        if cross_attend:
            # Cross-attention branch and its own pre-norm.
            self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
            self.cross_attn = Attention(
                dim,
                dim_heads = dim_heads,
                dim_context=dim_context,
                causal = causal,
                zero_init_output=zero_init_branch_outputs,
                dtype=dtype,
                device=device,
                operations=operations,
                **attn_kwargs
            )

        self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)

        # Index of this layer within the stack (informational).
        self.layer_ix = layer_ix

        self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None

        self.global_cond_dim = global_cond_dim

        if global_cond_dim is not None:
            # Projects the global conditioning vector to 6 adaLN parameters:
            # (scale, shift, gate) for self-attention and for the feed-forward.
            self.to_scale_shift_gate = nn.Sequential(
                nn.SiLU(),
                nn.Linear(global_cond_dim, dim * 6, bias=False)
            )

            # Zero-init so modulation starts as identity-like behavior.
            nn.init.zeros_(self.to_scale_shift_gate[1].weight)
            #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)

    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None
    ):
        """Run one transformer layer over x (batch, seq, dim).

        Takes the adaLN-modulated path only when global conditioning is both
        configured and supplied; otherwise runs plain pre-norm residual blocks.
        """
        if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

            # One projection yields all six modulation tensors, broadcast over seq.
            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)

            # self-attention with adaLN
            residual = x
            x = self.pre_norm(x)
            x = x * (1 + scale_self) + shift_self
            x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
            # Gate the branch output via sigmoid(1 - gate) before the residual add.
            x = x * torch.sigmoid(1 - gate_self)
            x = x + residual

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            # feedforward with adaLN
            residual = x
            x = self.ff_norm(x)
            x = x * (1 + scale_ff) + shift_ff
            x = self.ff(x)
            x = x * torch.sigmoid(1 - gate_ff)
            x = x + residual

        else:
            # Plain pre-norm residual path (no adaLN modulation).
            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            x = x + self.ff(self.ff_norm(x))

        return x
|
| 531 |
+
|
| 532 |
+
class ContinuousTransformer(nn.Module):
    """Stack of TransformerBlocks over continuous (non-token) inputs, with
    optional input/output projections, rotary or absolute/sinusoidal position
    embeddings, and support for ComfyUI block patch replacement.
    """

    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in = None,
        dim_out = None,
        dim_heads = 64,
        cross_attend=False,
        cond_token_dim=None,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):

        super().__init__()

        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])

        # Optional linear adapters into and out of the model width.
        self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
        self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()

        if rotary_pos_emb:
            # Partial rotary embedding: at least 32 rotary dims per head.
            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
        else:
            self.rotary_pos_emb = None

        # NOTE: if both flags are set, use_abs_pos_emb overwrites self.pos_emb.
        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)

        for i in range(depth):
            self.layers.append(
                TransformerBlock(
                    dim,
                    dim_heads = dim_heads,
                    cross_attend = cross_attend,
                    dim_context = cond_token_dim,
                    global_cond_dim = global_cond_dim,
                    causal = causal,
                    zero_init_branch_outputs = zero_init_branch_outputs,
                    conformer=conformer,
                    layer_ix=i,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                    **kwargs
                )
            )

    def forward(
        self,
        x,
        mask = None,
        prepend_embeds = None,
        prepend_mask = None,
        global_cond = None,
        return_info = False,
        **kwargs
    ):
        """Run the transformer stack over x (batch, seq, dim_in or dim).

        Requires kwargs["context"] (cross-attention tokens; may be None when the
        blocks were built without cross-attention). Optional prepend_embeds are
        concatenated in front of the sequence.

        NOTE(review): the assembled ``mask`` is not forwarded to the layers below
        (the layer call passes only rotary/global_cond/context), so it currently
        has no effect on attention — confirm whether that is intended.
        """
        patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
        batch, seq, device = *x.shape[:2], x.device
        context = kwargs["context"]

        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)

        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim = -2)

            if prepend_mask is not None or mask is not None:
                # Default missing masks to all-ones so the concatenation lines up.
                mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
                prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)

                mask = torch.cat((prepend_mask, mask), dim = -1)

        # Attention layers

        if self.rotary_pos_emb is not None:
            # Rotary frequencies cover the full (possibly prepended) sequence.
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
        else:
            rotary_pos_emb = None

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        blocks_replace = patches_replace.get("dit", {})
        # Iterate over the transformer layers
        for i, layer in enumerate(self.layers):
            if ("double_block", i) in blocks_replace:
                # Route this layer through a user-supplied patch wrapper.
                def block_wrap(args):
                    out = {}
                    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
                # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

            if return_info:
                # Collect per-layer hidden states for feature extraction.
                info["hidden_states"].append(x)

        x = self.project_out(x)

        if return_info:
            return x, info

        return x
|
| 666 |
+
|
| 667 |
+
class AudioDiffusionTransformer(nn.Module):
    """Stable-Audio style diffusion transformer over 1D audio latents.

    Wraps a ContinuousTransformer with projections for timestep, global,
    cross-attention and prepend conditioning, plus residual 1x1 Conv1d
    layers before and after the transformer.
    """

    def __init__(self,
        io_channels=64,            # channels of the audio latent in/out
        patch_size=1,              # temporal patching factor
        embed_dim=1536,            # transformer width
        cond_token_dim=768,        # dim of cross-attn conditioning tokens (0 disables)
        project_cond_tokens=False, # if True, project cond tokens to embed_dim
        global_cond_dim=1536,      # dim of the global conditioning vector (0 disables)
        project_global_cond=True,  # if True, project global cond to embed_dim
        input_concat_dim=0,        # extra channels concatenated to the input
        prepend_cond_dim=0,        # dim of conditioning prepended to the sequence
        depth=24,
        num_heads=24,
        transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
        global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
        audio_model="",
        dtype=None,
        device=None,
        operations=None,
        **kwargs):

        super().__init__()

        self.dtype = dtype
        self.cond_token_dim = cond_token_dim

        # Timestep embeddings: Fourier features -> 2-layer MLP to embed_dim.
        timestep_features_dim = 256

        self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)

        self.to_timestep_embed = nn.Sequential(
            operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
        )

        if cond_token_dim > 0:
            # Conditioning tokens (cross-attention context).
            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = nn.Sequential(
                operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
            )
        else:
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning projection.
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = nn.Sequential(
                operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
            )

        if prepend_cond_dim > 0:
            # Prepend conditioning projection (always mapped to embed_dim).
            self.to_prepend_embed = nn.Sequential(
                operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
            )

        self.input_concat_dim = input_concat_dim

        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        # Transformer
        self.transformer_type = transformer_type

        self.global_cond_type = global_cond_type

        if self.transformer_type == "continuous_transformer":

            global_dim = None

            if self.global_cond_type == "adaLN":
                # The global conditioning is projected to the embed_dim already at this point
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend = cond_token_dim > 0,
                cond_token_dim = cond_embed_dim,
                global_cond_dim=global_dim,
                dtype=dtype,
                device=device,
                operations=operations,
                **kwargs
            )
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        # Residual 1x1 convs around the transformer (applied as conv(x) + x).
        self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
        self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)

    def _forward(
        self,
        x,
        t,
        mask=None,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        input_concat_cond=None,
        global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        return_info=False,
        **kwargs):
        """Core forward pass.

        x is channels-first audio latent (b, c, t) — established by the
        Conv1d layers and the "b c t -> b t c" rearrange below.
        t is a batch of timesteps (b,); t[:, None] makes it (b, 1) for the
        Fourier features.
        """

        if cross_attn_cond is not None:
            # NOTE(review): assumes cond_token_dim > 0 so to_cond_embed exists
            # — confirm callers never pass cross_attn_cond otherwise.
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:

            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            x = torch.cat([x, input_concat_cond], dim=1)

        # Get the batch of timestep embeddings
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)

        # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
        if global_embed is not None:
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed

        # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
        if self.global_cond_type == "prepend":
            if prepend_inputs is None:
                # Prepend inputs are just the global embed, and the mask is all ones
                prepend_inputs = global_embed.unsqueeze(1)
                prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
            else:
                # Prepend inputs are the prepend conditioning + the global embed
                prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
                prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)

            prepend_length = prepend_inputs.shape[1]

        # Residual 1x1 conv, then channels-last for the transformer.
        x = self.preprocess_conv(x) + x

        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        # NOTE(review): __init__ only constructs "continuous_transformer" and
        # raises for anything else, so the other branches here appear
        # vestigial — confirm before removing them.
        if self.transformer_type == "x-transformers":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
        elif self.transformer_type == "continuous_transformer":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)

            if return_info:
                output, info = output
        elif self.transformer_type == "mm_transformer":
            output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)

        # Back to channels-first and drop the prepended conditioning tokens.
        output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
        self,
        x,
        timestep,
        context=None,
        context_mask=None,
        input_concat_cond=None,
        global_embed=None,
        negative_global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        mask=None,
        return_info=False,
        control=None,
        **kwargs):
        """ComfyUI-facing entry point; maps context/context_mask onto the
        cross-attention kwargs of _forward.

        Note: negative_global_embed and control are accepted but not
        forwarded to _forward (visible below).
        """
        return self._forward(
            x,
            timestep,
            cross_attn_cond=context,
            cross_attn_cond_mask=context_mask,
            input_concat_cond=input_concat_cond,
            global_embed=global_embed,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            mask=mask,
            return_info=return_info,
            **kwargs
        )
|
ComfyUI/comfy/ldm/audio/embedders.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from typing import List, Union
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
import math
|
| 9 |
+
import comfy.ops
|
| 10 |
+
|
| 11 |
+
class LearnedPositionalEmbedding(nn.Module):
    """Fourier-style embedding with learned frequencies, used for continuous time.

    Maps a batch of scalars (b,) to (b, dim + 1): the raw value followed by
    sin/cos features at dim // 2 learned frequencies.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.empty(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        # (b,) -> (b, 1), then project onto the learned frequencies.
        t = x[:, None]
        angles = t * self.weights[None, :] * 2 * math.pi
        # Output layout: [raw value, sin features, cos features].
        return torch.cat((t, angles.sin(), angles.cos()), dim=-1)
|
| 26 |
+
|
| 27 |
+
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a module mapping scalar times (b,) to (b, out_features).

    Composes a LearnedPositionalEmbedding (which emits dim + 1 features:
    the raw value plus sin/cos pairs) with a manual-cast Linear projection.
    """
    positional = LearnedPositionalEmbedding(dim)
    projection = comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(positional, projection)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class NumberEmbedder(nn.Module):
    """Embed scalar values into `features`-dim vectors via a learned
    time-positional embedding; preserves the input's leading shape.
    """

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if not torch.is_tensor(x):
            # Build the tensor on the same device as the embedding weights.
            target_device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=target_device)
        assert isinstance(x, Tensor)
        original_shape = x.shape
        # Flatten all leading dims, embed, then restore them with a trailing
        # feature axis.
        embedded = self.embedding(x.reshape(-1))
        return embedded.view(*original_shape, self.features)  # type: ignore
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Conditioner(nn.Module):
    """Abstract base for conditioning modules.

    Holds an optional output projection: identity when dim == output_dim and
    project_out is False, otherwise a Linear(dim, output_dim). Subclasses
    implement forward().
    """

    def __init__(
        self,
        dim: int,
        output_dim: int,
        project_out: bool = False
    ):
        super().__init__()

        self.dim = dim
        self.output_dim = output_dim
        needs_projection = project_out or (dim != output_dim)
        self.proj_out = nn.Linear(dim, output_dim) if needs_projection else nn.Identity()

    def forward(self, x):
        raise NotImplementedError()
|
| 72 |
+
|
| 73 |
+
class NumberConditioner(Conditioner):
    """Conditioner for lists of floats.

    Clamps each value to [min_val, max_val], normalizes to [0, 1], and
    returns [embeddings (b, 1, output_dim), all-ones mask (b, 1)].
    """

    def __init__(self,
        output_dim: int,
        min_val: float = 0,
        max_val: float = 1
    ):
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats, device=None):
        # Coerce arbitrary numeric inputs to python floats before tensorizing.
        values = [float(v) for v in floats]

        if device is None:
            device = next(self.embedder.parameters()).device

        values = torch.tensor(values).to(device)
        values = values.clamp(self.min_val, self.max_val)

        normalized = (values - self.min_val) / (self.max_val - self.min_val)

        # Match the embedder's parameter dtype before embedding.
        embedder_dtype = next(self.embedder.parameters()).dtype
        normalized = normalized.to(embedder_dtype)

        float_embeds = self.embedder(normalized).unsqueeze(1)

        return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
|
ComfyUI/comfy/ldm/aura/mmdit.py
ADDED
|
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#AuraFlow MMDiT
|
| 2 |
+
#Originally written by the AuraFlow Authors
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 11 |
+
import comfy.ops
|
| 12 |
+
import comfy.ldm.common_dit
|
| 13 |
+
|
| 14 |
+
def modulate(x, shift, scale):
    """adaLN-style affine modulation: x * (1 + scale) + shift.

    shift/scale are (b, d) and are broadcast over x's sequence axis (dim 1).
    """
    gain = 1 + scale.unsqueeze(1)
    return x * gain + shift.unsqueeze(1)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def find_multiple(n: int, k: int) -> int:
    """Round n up to the nearest multiple of k (n itself if already one)."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MLP(nn.Module):
    """SwiGLU-style feed-forward: c_proj(silu(c_fc1(x)) * c_fc2(x)).

    The hidden width is 2/3 of `hidden_dim` rounded up to a multiple of 256
    (hidden_dim defaults to 4 * dim).
    """

    def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        hidden_dim = 4 * dim if hidden_dim is None else hidden_dim

        width = find_multiple(int(2 * hidden_dim / 3), 256)

        self.c_fc1 = operations.Linear(dim, width, bias=False, dtype=dtype, device=device)
        self.c_fc2 = operations.Linear(dim, width, bias=False, dtype=dtype, device=device)
        self.c_proj = operations.Linear(width, dim, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        return self.c_proj(gated)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class MultiHeadLayerNorm(nn.Module):
    """LayerNorm over the last axis with a learned scale and no bias,
    computed in float32 and cast back to the input dtype.

    Adapted from the HF Cohere LayerNorm (see original link upstream).
    """

    def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        h = hidden_states.to(torch.float32)
        mu = h.mean(-1, keepdim=True)
        var = (h - mu).pow(2).mean(-1, keepdim=True)
        normed = (h - mu) * torch.rsqrt(var + self.variance_epsilon)
        normed = self.weight.to(torch.float32) * normed
        return normed.to(orig_dtype)
|
| 61 |
+
|
| 62 |
+
class SingleAttention(nn.Module):
    """Self-attention over one stream with per-head q/k normalization."""

    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # q/k/v/out projections for the single (cond) stream.
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        def qk_norm():
            # Multi-head variant normalizes (head, head_dim) jointly; the
            # plain LayerNorm variant normalizes each head independently.
            if mh_qknorm:
                return MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            return operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)

        self.q_norm1 = qk_norm()
        self.k_norm1 = qk_norm()

    #@torch.compile()
    def forward(self, c):
        bsz, seqlen, _ = c.shape

        q = self.w1q(c).view(bsz, seqlen, self.n_heads, self.head_dim)
        k = self.w1k(c).view(bsz, seqlen, self.n_heads, self.head_dim)
        v = self.w1v(c).view(bsz, seqlen, self.n_heads, self.head_dim)
        q = self.q_norm1(q)
        k = self.k_norm1(k)

        # optimized_attention with skip_reshape=True takes (b, heads, seq, head_dim).
        attn_out = optimized_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            self.n_heads, skip_reshape=True,
        )
        return self.w1o(attn_out)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class DoubleAttention(nn.Module):
    """Joint self-attention over two streams: conditioning (c) and image (x).

    Each stream has its own q/k/v/out projections and q/k norms; the streams
    are concatenated along the sequence axis for one attention pass, then
    split back and projected per stream.
    """

    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        def linear():
            return operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        def qk_norm():
            if mh_qknorm:
                return MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            return operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)

        # conditioning-stream projections
        self.w1q = linear()
        self.w1k = linear()
        self.w1v = linear()
        self.w1o = linear()

        # image-stream projections
        self.w2q = linear()
        self.w2k = linear()
        self.w2v = linear()
        self.w2o = linear()

        self.q_norm1 = qk_norm()
        self.k_norm1 = qk_norm()
        self.q_norm2 = qk_norm()
        self.k_norm2 = qk_norm()

    #@torch.compile()
    def forward(self, c, x):
        bsz, c_len, _ = c.shape
        _, x_len, _ = x.shape

        def split_heads(t, length):
            return t.view(bsz, length, self.n_heads, self.head_dim)

        cq = split_heads(self.w1q(c), c_len)
        ck = split_heads(self.w1k(c), c_len)
        cv = split_heads(self.w1v(c), c_len)
        cq, ck = self.q_norm1(cq), self.k_norm1(ck)

        xq = split_heads(self.w2q(x), x_len)
        xk = split_heads(self.w2k(x), x_len)
        xv = split_heads(self.w2v(x), x_len)
        xq, xk = self.q_norm2(xq), self.k_norm2(xk)

        # Single attention pass over the concatenated streams.
        q = torch.cat([cq, xq], dim=1)
        k = torch.cat([ck, xk], dim=1)
        v = torch.cat([cv, xv], dim=1)

        out = optimized_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            self.n_heads, skip_reshape=True,
        )

        c_out, x_out = out.split([c_len, x_len], dim=1)
        return self.w1o(c_out), self.w2o(x_out)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class MMDiTBlock(nn.Module):
    """Two-stream MMDiT block: adaLN-modulated joint attention over the
    conditioning stream (c) and image stream (x), each with its own
    modulation, norms and MLP, joined by DoubleAttention.
    """

    def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        if not is_last:
            self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
            )
        else:
            # NOTE(review): the is_last variant has no mlpC and only a 2*dim
            # modulation, but forward() unconditionally chunks modC into 6 and
            # calls self.mlpC, so it only works with is_last=False. Confirm
            # whether an is_last block is ever actually executed.
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
            )

        self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
        self.modX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.is_last = is_last

    #@torch.compile()
    def forward(self, c, x, global_cond, **kwargs):

        cres, xres = c, x

        # Cond-stream modulation params: shift/scale/gate for attn and MLP.
        cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
            self.modC(global_cond).chunk(6, dim=1)
        )

        c = modulate(self.normC1(c), cshift_msa, cscale_msa)

        # xpath: same modulation scheme for the image stream.
        xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
            self.modX(global_cond).chunk(6, dim=1)
        )

        x = modulate(self.normX1(x), xshift_msa, xscale_msa)

        # attention (joint over both streams)
        c, x = self.attn(c, x)

        # Cond stream: gated attention residual, then gated MLP residual.
        c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
        c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
        c = cres + c

        # Image stream: same pattern.
        x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
        x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
        x = xres + x

        return c, x
|
| 239 |
+
|
| 240 |
+
class DiTBlock(nn.Module):
    # like MMDiTBlock, but it only has X
    """Single-stream DiT block: adaLN-modulated self-attention plus MLP,
    each with a gated residual connection back to the block input."""

    def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
        super().__init__()

        self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)

        self.modCX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

    #@torch.compile()
    def forward(self, cx, global_cond, **kwargs):
        residual = cx
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modCX(global_cond).chunk(6, dim=1)
        )

        # Attention with gated residual (residual is the block input).
        attn_in = modulate(self.norm1(cx), shift_msa, scale_msa)
        cx = self.norm2(residual + gate_msa.unsqueeze(1) * self.attn(attn_in))

        # MLP with gated residual back to the same block input.
        mlp_out = self.mlp(modulate(cx, shift_mlp, scale_mlp))
        return residual + gate_mlp.unsqueeze(1) * mlp_out
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class TimestepEmbedder(nn.Module):
    """Embed scalar diffusion timesteps via sinusoidal features + 2-layer MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal embedding of t into `dim` features ([cos | sin] halves).

        NOTE: frequencies carry a fixed 1000x factor — presumably because the
        model's t lives in [0, 1]; confirm against the sampler.
        """
        half = dim // 2
        exponents = torch.arange(start=0, end=half) / half
        freqs = 1000 * torch.exp(-math.log(max_period) * exponents).to(t.device)
        angles = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Odd dims get a zero-padded final column.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    #@torch.compile()
    def forward(self, t, dtype):
        features = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        return self.mlp(features)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class MMDiT(nn.Module):
    """AuraFlow MMDiT: n_double_layers two-stream (cond + image) blocks
    followed by single-stream DiT blocks, with a learned positional
    encoding over a square token grid.
    """

    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        patch_size=2,
        dim=3072,
        n_layers=36,
        n_double_layers=4,
        n_heads=12,
        global_conddim=3072,
        cond_seq_dim=2048,
        max_seq=32 * 32,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype

        self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)

        self.cond_seq_linear = operations.Linear(
            cond_seq_dim, dim, bias=False, dtype=dtype, device=device
        )  # linear for something like text sequence.
        self.init_x_linear = operations.Linear(
            patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
        )  # init linear for patchified image.

        # Learned positional encoding and register tokens; torch.empty leaves
        # them uninitialized — presumably loaded from a checkpoint. TODO confirm.
        self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
        self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))

        self.double_layers = nn.ModuleList([])
        self.single_layers = nn.ModuleList([])

        # NOTE(review): is_last compares against n_layers, so it is only True
        # when n_double_layers == n_layers — confirm this matches intent.
        for idx in range(n_double_layers):
            self.double_layers.append(
                MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
            )

        for idx in range(n_double_layers, n_layers):
            self.single_layers.append(
                DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
            )

        self.final_linear = operations.Linear(
            dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
        )

        # Produces the final shift/scale pair from the global conditioning.
        self.modF = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
        )

        self.out_channels = out_channels
        self.patch_size = patch_size
        self.n_double_layers = n_double_layers
        self.n_layers = n_layers

        # Positional-encoding grid is assumed square: max_seq = h_max * w_max.
        self.h_max = round(max_seq**0.5)
        self.w_max = round(max_seq**0.5)
|
| 368 |
+
|
| 369 |
+
@torch.no_grad()
def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
    """Grow the learned 2D positional encoding in place from init_dim to target_dim.

    The stored flat PE is viewed as an (init_h, init_w) grid, bilinearly
    interpolated up to target_dim, and written back into the Parameter's data.
    Also updates self.h_max / self.w_max to the new grid size.
    """
    # extend pe
    pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]

    # (H*W, D) -> (D, H, W) so interpolate treats the embedding dim as channels.
    pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)

    # now we need to extend this to target_dim. for this we will use interpolation.
    # we will use torch.nn.functional.interpolate
    pe_as_2d = F.interpolate(
        pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
    )
    pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
    self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
    self.h_max, self.w_max = target_dim
|
| 384 |
+
|
| 385 |
+
def pe_selection_index_based_on_dim(self, h, w):
    """Return flat PE indices for a centered crop of the PE grid.

    h, w are pixel dims; they are converted to patch counts and a
    (patches_h, patches_w) window centered in the (h_max, w_max) grid is
    selected, returned as flattened indices into the stored PE.
    """
    patches_h = h // self.patch_size
    patches_w = w // self.patch_size
    # Lay all stored PE positions out on the full (h_max, w_max) grid.
    grid = torch.arange(self.positional_encoding.shape[1]).view(self.h_max, self.w_max)
    # Center-crop the grid to the requested patch resolution.
    top = self.h_max // 2 - patches_h // 2
    left = self.w_max // 2 - patches_w // 2
    window = grid[top:top + patches_h, left:left + patches_w]
    return window.flatten()
|
| 397 |
+
|
| 398 |
+
def unpatchify(self, x, h, w):
    """Fold a (B, h*w, p*p*C) patch sequence back into a (B, C, h*p, w*p) image."""
    channels = self.out_channels
    p = self.patch_size

    # (B, h*w, p*p*C) -> (B, h, w, p, p, C)
    x = x.reshape(x.shape[0], h, w, p, p, channels)
    # Bring channels first and interleave each patch's pixels with its grid cell.
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(x.shape[0], channels, h * p, w * p)
|
| 406 |
+
|
| 407 |
+
def patchify(self, x):
    """Split a (B, C, H, W) image into a (B, N, C*p*p) sequence of flat patches.

    Pads H/W up to a multiple of patch_size first via comfy's pad helper.
    """
    B, C, H, W = x.size()
    x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
    # NOTE(review): the (H + 1) // patch_size grid math matches the padding
    # only for patch_size == 2 (rounds odd H/W up) -- confirm if other patch
    # sizes are ever used.
    x = x.view(
        B,
        C,
        (H + 1) // self.patch_size,
        self.patch_size,
        (W + 1) // self.patch_size,
        self.patch_size,
    )
    # (B, C, h, p, w, p) -> (B, h, w, C, p, p) -> (B, h*w, C*p*p)
    x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
    return x
|
| 420 |
+
|
| 421 |
+
def apply_pos_embeds(self, x, h, w):
    """Add the stored 2D positional encoding to patch sequence x.

    h, w are pre-patchify pixel dims. The PE grid is bilinearly enlarged when
    the image needs more positions than are stored, then center-cropped down
    to exactly (h, w) patch positions before being added.
    """
    h = (h + 1) // self.patch_size
    w = (w + 1) // self.patch_size
    max_dim = max(h, w)

    cur_dim = self.h_max
    # Cast the stored PE to x's dtype/device, shaped as a (1, H, W, D) grid.
    pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)

    if max_dim > cur_dim:
        # Image larger than the stored PE grid: upscale the grid to fit.
        pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
        cur_dim = max_dim

    # Center-crop the grid down to exactly (h, w) positions.
    from_h = (cur_dim - h) // 2
    from_w = (cur_dim - w) // 2
    pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
    return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
|
| 437 |
+
|
| 438 |
+
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
    """Run the DiT: patchify + PE, project conditioning, apply double then
    single transformer blocks (with optional per-block replacement patches),
    modulate, project out, and unpatchify back to image space.

    Note: transformer_options uses a mutable default but is only read
    (via .get), never mutated, so the shared default is safe here.
    """
    patches_replace = transformer_options.get("patches_replace", {})
    # patchify x, add PE
    b, c, h, w = x.shape

    # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
    # print(pe_indexes, pe_indexes.shape)

    x = self.init_x_linear(self.patchify(x))  # B, T_x, D
    x = self.apply_pos_embeds(x, h, w)
    # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
    # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)

    # process conditions for MMDiT Blocks
    c_seq = context  # B, T_c, D_c
    t = timestep

    # Project text conditioning and prepend the learned register tokens.
    c = self.cond_seq_linear(c_seq)  # B, T_c, D
    c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)

    global_cond = self.t_embedder(t, x.dtype)  # B, D

    blocks_replace = patches_replace.get("dit", {})
    if len(self.double_layers) > 0:
        # Double (MMDiT) blocks update text and image streams jointly.
        for i, layer in enumerate(self.double_layers):
            if ("double_block", i) in blocks_replace:
                # Wrap the original layer so a replacement patch can call it.
                def block_wrap(args):
                    out = {}
                    out["txt"], out["img"] = layer(args["txt"],
                                                   args["img"],
                                                   args["vec"])
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
                c = out["txt"]
                x = out["img"]
            else:
                c, x = layer(c, x, global_cond, **kwargs)

    if len(self.single_layers) > 0:
        # Single blocks run on the concatenated [text, image] sequence.
        c_len = c.size(1)
        cx = torch.cat([c, x], dim=1)
        for i, layer in enumerate(self.single_layers):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = layer(args["img"], args["vec"])
                    return out

                out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
                cx = out["img"]
            else:
                cx = layer(cx, global_cond, **kwargs)

        # Drop the text tokens again, keeping only the image part.
        x = cx[:, c_len:]

    # Final timestep-conditioned shift/scale before the output projection.
    fshift, fscale = self.modF(global_cond).chunk(2, dim=1)

    x = modulate(x, fshift, fscale)
    x = self.final_linear(x)
    # Crop away any padding added by patchify.
    x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:, :, :h, :w]
    return x
|
ComfyUI/comfy/ldm/cascade/common.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Stability AI
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 22 |
+
import comfy.ops
|
| 23 |
+
|
| 24 |
+
class OptimizedAttention(nn.Module):
    """Multi-head attention with separate q/k/v projections, backed by
    comfy's optimized_attention kernel.

    Note: the ``dropout`` argument is accepted for signature compatibility
    but is not used by this implementation.
    """
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead

        self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, q, k, v):
        """Project q/k/v, run fused attention, and project the result back."""
        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)

        out = optimized_attention(q, k, v, self.heads)

        return self.out_proj(out)
|
| 43 |
+
|
| 44 |
+
class Attention2D(nn.Module):
    """Attention over the spatial positions of a (B, C, H, W) feature map."""

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
        # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)

    def forward(self, x, kv, self_attn=False):
        """Attend from the flattened map to kv (optionally including itself)."""
        shape_in = x.shape
        # (B, C, H, W) -> (B, H*W, C) token sequence.
        tokens = x.view(shape_in[0], shape_in[1], -1).permute(0, 2, 1)
        keys = torch.cat([tokens, kv], dim=1) if self_attn else kv
        # x = self.attn(x, kv, kv, need_weights=False)[0]
        tokens = self.attn(tokens, keys, keys)
        # Restore the original spatial layout.
        return tokens.permute(0, 2, 1).view(*shape_in)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def LayerNorm2d_op(operations):
    """Build a LayerNorm class (from the given operations set) that accepts
    channels-first (B, C, H, W) input by normalizing over the channel dim.

    The previous pass-through ``__init__`` override was redundant and has
    been removed; construction is inherited unchanged from operations.LayerNorm.
    """
    class LayerNorm2d(operations.LayerNorm):
        def forward(self, x):
            # Move channels last so LayerNorm normalizes over C, then restore.
            return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    return LayerNorm2d
|
| 69 |
+
|
| 70 |
+
class GlobalResponseNorm(nn.Module):
    """Global Response Normalization (ConvNeXt-V2) for channels-last input.

    from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    """
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        # Learnable per-channel scale/shift; allocated with torch.empty, so
        # values are expected to come from a loaded checkpoint.
        self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))

    def forward(self, x):
        # Assumes channels-last x (B, H, W, C): L2 norm over spatial dims per channel.
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Normalize each channel's response by the mean response across channels.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ResBlock(nn.Module):
    """ConvNeXt-style residual block: depthwise conv + GRN-gated channel MLP."""
    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None):  # , num_heads=4, expansion=2):
        super().__init__()
        self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
        # self.depthwise = SAMBlock(c, num_heads, expansion)
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        # Channelwise MLP; operates channels-last, so input is permuted in forward.
        self.channelwise = nn.Sequential(
            operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device)
        )

    def forward(self, x, x_skip=None):
        x_res = x
        x = self.norm(self.depthwise(x))
        if x_skip is not None:
            # Optional skip features are concatenated on the channel dim.
            x = torch.cat([x, x_skip], dim=1)
        x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x + x_res
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class AttnBlock(nn.Module):
    """Residual cross/self-attention block: x attends to mapped conditioning kv."""

    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_cond, c, dtype=dtype, device=device)
        )

    def forward(self, x, kv):
        """Map the conditioning into this block's width, then add the attention residual."""
        mapped = self.kv_mapper(kv)
        return x + self.attention(self.norm(x), mapped, self_attn=self.self_attn)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class FeedForwardBlock(nn.Module):
    """Residual channelwise MLP: norm -> 4x expand -> GELU -> GRN -> project."""
    def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.channelwise = nn.Sequential(
            operations.Linear(c, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device)
        )

    def forward(self, x):
        # (B, C, H, W) -> channels-last for the Linear layers, then back.
        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class TimestepBlock(nn.Module):
    """Timestep-dependent affine modulation: x * (1 + a) + b.

    The modulation is the sum of one linear mapper applied to the timestep
    embedding plus one mapper per extra condition name in ``conds``.
    The default for ``conds`` is now a tuple (immutable) instead of a list
    to avoid the shared-mutable-default pitfall; behavior is unchanged.
    """
    def __init__(self, c, c_timestep, conds=('sca',), dtype=None, device=None, operations=None):
        super().__init__()
        # Base mapper for the timestep embedding itself.
        self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
        self.conds = conds
        # One extra mapper per additional condition (e.g. 'sca', 'crp').
        for cname in conds:
            setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))

    def forward(self, x, t):
        # t packs the timestep embedding plus one embedding per condition along dim 1.
        t = t.chunk(len(self.conds) + 1, dim=1)
        # a = scale, b = shift; broadcast over the spatial dims.
        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
        for i, c in enumerate(self.conds):
            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
|
ComfyUI/comfy/ldm/cascade/controlnet.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Stability AI
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torchvision
|
| 20 |
+
from torch import nn
|
| 21 |
+
from .common import LayerNorm2d_op
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CNetResBlock(nn.Module):
    """Plain residual conv block (norm -> GELU -> conv, twice) used by the
    'large' ControlNet bottleneck."""
    def __init__(self, c, dtype=None, device=None, operations=None):
        super().__init__()
        self.blocks = nn.Sequential(
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            # Fix: dtype/device were previously dropped from the Conv2d calls,
            # inconsistent with the LayerNorm layers and the rest of the file.
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
        )

    def forward(self, x):
        return x + self.blocks(x)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ControlNet(nn.Module):
    """Stable Cascade ControlNet feature extractor.

    Encodes the control image with one of three backbones ('effnet', 'simple',
    'large') and emits one projected feature map per entry in proj_blocks.
    """
    def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
        super().__init__()
        if bottleneck_mode is None:
            bottleneck_mode = 'effnet'
        self.proj_blocks = proj_blocks
        if bottleneck_mode == 'effnet':
            embd_channels = 1280
            # EfficientNetV2-S feature extractor, kept in eval mode.
            self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
            if c_in != 3:
                # Swap the stem conv to accept c_in channels, reusing RGB weights.
                in_weights = self.backbone[0][0].weight.data
                self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
                if c_in > 3:
                    # nn.init.constant_(self.backbone[0][0].weight, 0)
                    self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
                else:
                    self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
        elif bottleneck_mode == 'simple':
            embd_channels = c_in
            self.backbone = nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
            )
        elif bottleneck_mode == 'large':
            self.backbone = nn.Sequential(
                operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
                *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
                operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
            )
            embd_channels = 1280
        else:
            raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
        # One 2-conv projection head per target block index in proj_blocks.
        self.projections = nn.ModuleList()
        for _ in range(len(proj_blocks)):
            self.projections.append(nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
            ))
        # nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
        self.xl = False
        self.input_channels = c_in
        self.unshuffle_amount = 8

    def forward(self, x):
        x = self.backbone(x)
        # Scatter each projection's output to its target block index; unused
        # slots stay None so the consumer can skip them.
        proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
        for i, idx in enumerate(self.proj_blocks):
            proj_outputs[idx] = self.projections[i](x)
        # Outputs are reversed; presumably to match the consumer's block
        # traversal order -- confirm against the controlnet caller.
        return {"input": proj_outputs[::-1]}
|
ComfyUI/comfy/ldm/cascade/stage_a.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Stability AI
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch import nn
|
| 21 |
+
from torch.autograd import Function
|
| 22 |
+
import comfy.ops
|
| 23 |
+
|
| 24 |
+
ops = comfy.ops.disable_weight_init
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class vector_quantize(Function):
    """Straight-through nearest-neighbour vector quantization.

    forward returns (quantized_vectors, indices); backward passes the output
    gradient straight through to the inputs and accumulates it per-index into
    the codebook.
    """
    @staticmethod
    def forward(ctx, x, codebook):
        with torch.no_grad():
            # Squared L2 distances via ||c||^2 + ||x||^2 - 2 x.c^T (addmm).
            codebook_sqr = torch.sum(codebook ** 2, dim=1)
            x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)

            dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
            _, indices = dist.min(dim=1)

            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)

            # Fix: this local was named `nn`, shadowing the module-level
            # torch.nn import inside this function.
            codes = torch.index_select(codebook, 0, indices)
            return codes, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        grad_inputs, grad_codebook = None, None

        if ctx.needs_input_grad[0]:
            # Straight-through estimator: copy the gradient to the encoder output.
            grad_inputs = grad_output.clone()
        if ctx.needs_input_grad[1]:
            # Gradient wrt. the codebook
            indices, codebook = ctx.saved_tensors

            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output)

        return (grad_inputs, grad_codebook)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class VectorQuantize(nn.Module):
    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
        with the same size as the input, vq and commitment components for the loss as a tuple
        in the second output and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices
        """
        super(VectorQuantize, self).__init__()

        self.codebook = nn.Embedding(k, embedding_size)
        self.codebook.weight.data.uniform_(-1./k, 1./k)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            # EMA statistics for codebook updates (used only while training).
            self.register_buffer('ema_element_count', torch.ones(k))
            self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        # Keep all counts strictly positive so the EMA codebook update
        # never divides by zero, while preserving the total count.
        n = torch.sum(x)
        return ((x + epsilon) / (n + x.size(0) * epsilon) * n)

    def _updateEMA(self, z_e_x, indices):
        # One-hot assignment mask: which encoder outputs mapped to which code.
        mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
        elem_count = mask.sum(dim=0)
        weight_sum = torch.mm(mask.t(), z_e_x)

        # Exponential moving averages of per-code counts and vector sums.
        self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1 - self.ema_decay) * elem_count)
        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
        self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1 - self.ema_decay) * weight_sum)

        # Codebook entries become the EMA mean of their assigned vectors.
        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        # Look up embeddings for indices; optionally move the embedding dim.
        q_idx = self.codebook(idx)
        if dim != -1:
            q_idx = q_idx.movedim(-1, dim)
        return q_idx

    def forward(self, x, get_losses=True, dim=-1):
        if dim != -1:
            x = x.movedim(dim, -1)
        # Flatten all leading dims so quantization sees (N, embedding_size).
        z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
        vq_loss, commit_loss = None, None
        if self.ema_loss and self.training:
            self._updateEMA(z_e_x.detach(), indices.detach())
        # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
        z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
        if get_losses:
            vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
            commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()

        z_q_x = z_q_x.view(x.shape)
        if dim != -1:
            z_q_x = z_q_x.movedim(-1, dim)
        return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class ResBlock(nn.Module):
    """Depthwise + channelwise residual block with learned per-branch gains."""
    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            ops.Conv2d(c, c, kernel_size=3, groups=c)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            ops.Linear(c, c_hidden),
            nn.GELU(),
            ops.Linear(c_hidden, c),
        )

        # Six scalar gains: (scale, shift, gate) for each of the two branches.
        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights
        def _basic_init(module):
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        # Apply a channels-last LayerNorm to a (B, C, H, W) tensor.
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        mods = self.gammas

        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        # Fix: was a bare `except:` (would also swallow KeyboardInterrupt);
        # narrowed to Exception while keeping the bf16 fallback behavior.
        try:
            x = x + self.depthwise(x_temp) * mods[2]
        except Exception:  # operation not implemented for bf16
            # Run the replication pad in fp32, then the conv in the input dtype.
            x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
            x = x + self.depthwise[1](x_temp) * mods[2]

        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]

        return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class StageA(nn.Module):
    """Stage A VQ autoencoder for Stable Cascade (pixels <-> discrete latent)."""
    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
        super().__init__()
        self.c_latent = c_latent
        # Channel count per level, smallest first: c_hidden // 2**(levels-1) ... c_hidden.
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)
        # Fix: removed a stray no-op expression statement (`self.down_blocks[0]`)
        # that indexed the sequential and discarded the result.

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                        padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        """Encode pixels to the latent; with quantize=True also return quantized
        latent, indices, and the combined VQ + 0.25 * commitment loss."""
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
            return qe, x, indices, vq_loss + commit_loss * 0.25
        else:
            return x

    def decode(self, x):
        """Decode a (quantized or raw) latent back to pixel space."""
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        # NOTE(review): the 4-way unpack requires encode() to have quantized;
        # calling with quantize=False would try to unpack a bare tensor --
        # confirm all callers of forward() pass quantize=True.
        qe, x, _, vq_loss = self.encode(x, quantize)
        x = self.decode(qe)
        return x, vq_loss
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class Discriminator(nn.Module):
    """Patch discriminator with spectral-norm convs and optional conditioning."""
    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        layers = [
            nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            # Channels double each level until they cap at c_hidden.
            c_in = c_hidden // (2 ** max((d - i), 0))
            c_out = c_hidden // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)
        self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        x = self.encoder(x)
        if cond is not None:
            # Broadcast the conditioning vector over the spatial map and concat
            # on the channel dim before the 1x1 output conv.
            cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = self.logits(x)
        return x
|
ComfyUI/comfy/ldm/cascade/stage_b.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Stability AI
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
import torch
|
| 21 |
+
from torch import nn
|
| 22 |
+
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
| 23 |
+
|
| 24 |
+
class StageB(nn.Module):
    """Stage B of the Stable Cascade (Würstchen) pipeline.

    A UNet-style denoiser that upsamples Stage C's compressed ("effnet")
    latents into Stage A's VQ latent space, conditioned on:
      * a sinusoidal timestep embedding (plus extra `t_conds` embeddings),
      * CLIP embeddings mapped to ``c_cond`` channels,
      * the effnet latents and an optional low-res pixel image, both
        bilinearly resized onto the feature map and added after embedding.

    Levels are described by `level_config` strings where each character
    builds one block: 'C' = ResBlock, 'A' = AttnBlock (cross-attn on CLIP),
    'F' = FeedForwardBlock, 'T' = TimestepBlock.
    """

    def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
                 nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
                 block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
                 c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
                 t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
        # `operations` supplies the Conv2d/Linear/LayerNorm implementations
        # (ComfyUI swaps these for weight-offloading variants).
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r  # width of one timestep embedding chunk
        self.t_conds = t_conds  # extra scalar conditions embedded like timesteps
        self.c_clip_seq = c_clip_seq  # how many cond tokens each CLIP vector expands to
        # Scalars are broadcast to one value per level.
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.effnet_mapper = nn.Sequential(
            operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        self.pixels_mapper = nn.Sequential(
            operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        # Each CLIP vector becomes `c_clip_seq` conditioning tokens.
        self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        # Patchify the input, then project to the first level's width.
        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            # Factory mapping a level-config character to a concrete block.
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            # Strided conv halves resolution between levels (level 0 keeps it).
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            # Optional 1x1 convs inserted between weight-tied repeats of a level.
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        # Mirror of the down path, iterated deepest level first.
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # Only the very first block of a non-deepest level
                    # receives a UNet skip connection from the down path.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
                                      self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT: project back to c_out channels and un-patchify.
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # (kept from upstream for reference; inference-only code paths do
        # not run it)
        # self.apply(self._init_weights)  # General init
        # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
        # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
        # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
        # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
        # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
        # nn.init.constant_(self.clf[1].weight, 0) # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
        #
        # def _init_weights(self, m):
        #     if isinstance(m, (nn.Conv2d, nn.Linear)):
        #         torch.nn.init.xavier_uniform_(m.weight)
        #         if m.bias is not None:
        #             nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of the (0..1) timestep *r* -> (B, c_r)."""
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip):
        """Expand CLIP vectors into `c_clip_seq` normalized cond tokens each."""
        if len(clip.shape) == 2:
            clip = clip.unsqueeze(1)
        clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip):
        """Run the encoder path; return per-level features, deepest first.

        Each block may be wrapped by FSDP (`_fsdp_wrapped_module`), so type
        dispatch checks both the block and its wrapped module.
        """
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            # `len(repmap) + 1` passes: the level is repeated, with a 1x1
            # mapper conv between consecutive repeats.
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip):
        """Run the decoder path, consuming `level_outputs` as skip inputs."""
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        # First block of each non-deepest level consumes the
                        # matching down-path feature map as a skip input.
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            # Resolutions can disagree for odd input sizes.
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
        """Denoise *x* at timestep *r*, conditioned on effnet/CLIP/pixels.

        Extra keyword args named in `self.t_conds` provide additional
        scalar conditions; missing ones default to zeros.
        """
        if pixels is None:
            # Placeholder low-res pixel conditioning when none is provided.
            pixels = x.new_zeros(x.size(0), 3, 8, 8)

        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip)

        # Model Blocks
        x = self.embedding(x)
        # Both conditioning maps are resized to the feature resolution and
        # added in before the UNet proper.
        x = x + self.effnet_mapper(
            nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
        x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
                                          align_corners=True)
        level_outputs = self._down_encode(x, r_embed, clip)
        x = self._up_decode(level_outputs, r_embed, clip)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """In-place EMA update of this model's params/buffers from *src_model*."""
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
|
ComfyUI/comfy/ldm/cascade/stage_c.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Stability AI
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch import nn
|
| 21 |
+
import math
|
| 22 |
+
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
| 23 |
+
# from .controlnet import ControlNetDeliverer
|
| 24 |
+
|
| 25 |
+
class UpDownBlock2d(nn.Module):
    """A 1x1 channel projection combined with optional 2x bilinear resampling.

    With ``enabled=False`` the resampling step is an identity, leaving only
    the channel mapping. For ``mode='up'`` the input is upsampled before the
    projection; for ``mode='down'`` it is projected first, then downsampled.
    """

    def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert mode in ['up', 'down']
        if enabled:
            scale = 2 if mode == 'up' else 0.5
            resample = nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True)
        else:
            resample = nn.Identity()
        project = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
        order = [resample, project] if mode == 'up' else [project, resample]
        self.blocks = nn.ModuleList(order)

    def forward(self, x):
        for layer in self.blocks:
            x = layer(x)
        return x
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class StageC(nn.Module):
    """Stage C of the Stable Cascade (Würstchen) pipeline.

    The main text-conditional diffusion prior operating in the highly
    compressed 16-channel latent space. Structurally similar to StageB but
    conditioned on CLIP text (sequence + pooled) and CLIP image embeddings,
    and with optional ControlNet residuals injected before each ResBlock.

    `level_config` characters build the blocks of each level:
    'C' = ResBlock, 'A' = AttnBlock, 'F' = FeedForwardBlock,
    'T' = TimestepBlock.
    """

    def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
                 blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
                 c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
                 dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
                 dtype=None, device=None, operations=None):
        # `operations` supplies the Conv2d/Linear/LayerNorm implementations
        # (ComfyUI swaps these for weight-offloading variants).
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r  # width of one timestep embedding chunk
        self.t_conds = t_conds  # extra scalar conditions embedded like timesteps
        self.c_clip_seq = c_clip_seq  # tokens each pooled/img CLIP vector expands to
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
        self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        # Patchify the input, then project to the first level's width.
        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            # NOTE(review): unlike StageB, dtype/device are not forwarded to
            # this LayerNorm — confirm whether that is intentional upstream.
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            # Factory mapping a level-config character to a concrete block.
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            # `switch_level[i-1]` toggles whether the transition between
            # levels actually resamples (UpDownBlock2d is identity otherwise).
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            # Optional 1x1 convs inserted between weight-tied repeats of a level.
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        # Mirror of the down path, iterated deepest level first.
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # Only the very first block of a non-deepest level
                    # receives a UNet skip connection from the down path.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
                                      self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT: project back to c_out channels and un-patchify.
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # (kept from upstream for reference; inference-only code paths do
        # not run it)
        # self.apply(self._init_weights)  # General init
        # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
        # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
        # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
        # nn.init.constant_(self.clf[1].weight, 0) # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
        #
        # def _init_weights(self, m):
        #     if isinstance(m, (nn.Conv2d, nn.Linear)):
        #         torch.nn.init.xavier_uniform_(m.weight)
        #         if m.bias is not None:
        #             nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of the (0..1) timestep *r* -> (B, c_r)."""
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
        """Concatenate text-sequence, expanded pooled-text, and expanded
        image CLIP embeddings into one normalized conditioning sequence."""
        clip_txt = self.clip_txt_mapper(clip_txt)
        if len(clip_txt_pooled.shape) == 2:
            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
        if len(clip_img.shape) == 2:
            clip_img = clip_img.unsqueeze(1)
        clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
        clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
        clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip, cnet=None):
        """Run the encoder path; return per-level features, deepest first.

        *cnet*, when given, is a list of ControlNet residual maps popped one
        per ResBlock and added (resized) before the block runs. Blocks may be
        FSDP-wrapped, hence the `_fsdp_wrapped_module` checks.
        """
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            # `len(repmap) + 1` passes: the level is repeated, with a 1x1
            # mapper conv between consecutive repeats.
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        if cnet is not None:
                            next_cnet = cnet.pop()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True).to(x.dtype)
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
        """Run the decoder path, consuming `level_outputs` as skip inputs
        and, when *cnet* is given, ControlNet residuals before each ResBlock."""
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        # First block of each non-deepest level consumes the
                        # matching down-path feature map as a skip input.
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            # Resolutions can disagree for odd input sizes.
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        if cnet is not None:
                            next_cnet = cnet.pop()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True).to(x.dtype)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
        """Denoise latents *x* at timestep *r* given CLIP conditioning.

        *control* (optional) is a dict whose "input" entry holds ControlNet
        residuals; extra kwargs named in `self.t_conds` provide additional
        scalar conditions, defaulting to zeros when absent.
        """
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)

        if control is not None:
            cnet = control.get("input")
        else:
            cnet = None

        # Model Blocks
        x = self.embedding(x)
        level_outputs = self._down_encode(x, r_embed, clip, cnet)
        x = self._up_decode(level_outputs, r_embed, clip, cnet)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """In-place EMA update of this model's params/buffers from *src_model*."""
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
|
ComfyUI/comfy/ldm/cascade/stage_c_coder.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Stability AI
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
import torch
|
| 19 |
+
import torchvision
|
| 20 |
+
from torch import nn
|
| 21 |
+
|
| 22 |
+
import comfy.ops
|
| 23 |
+
|
| 24 |
+
ops = comfy.ops.disable_weight_init
|
| 25 |
+
|
| 26 |
+
# EfficientNet
|
| 27 |
+
class EfficientNetEncoder(nn.Module):
    """EfficientNetV2-S feature encoder producing Stage C latents.

    Takes images in [-1, 1], re-normalizes them to ImageNet statistics,
    runs a torchvision EfficientNetV2-S feature stack, and maps the
    1280-channel features down to ``c_latent`` channels.
    """
    def __init__(self, c_latent=16):
        super().__init__()
        # Backbone is constructed in eval mode; only its feature extractor is used.
        self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
        self.mapper = nn.Sequential(
            ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
        )
        # ImageNet per-channel mean/std, stored as Parameters so they travel with the state_dict.
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))

    def forward(self, x):
        # Map from [-1, 1] to [0, 1] before applying ImageNet normalization.
        x = x * 0.5 + 0.5
        x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
        # Backbone features -> c_latent-channel latents.
        o = self.mapper(self.backbone(x))
        return o
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
|
| 46 |
+
class Previewer(nn.Module):
    """Lightweight decoder turning Stage C latents into preview RGB images.

    Upsamples 8x overall via three stride-2 transposed convolutions,
    e.g. 16 x 24 x 24 latents -> 3 x 192 x 192 pixels.
    """

    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        half = c_hidden // 2
        quarter = c_hidden // 4
        stages = [
            # 16 channels to 512 channels
            ops.Conv2d(c_in, c_hidden, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            # 16 -> 32
            ops.ConvTranspose2d(c_hidden, half, kernel_size=2, stride=2),
            nn.GELU(),
            nn.BatchNorm2d(half),

            ops.Conv2d(half, half, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(half),

            # 32 -> 64
            ops.ConvTranspose2d(half, quarter, kernel_size=2, stride=2),
            nn.GELU(),
            nn.BatchNorm2d(quarter),

            ops.Conv2d(quarter, quarter, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(quarter),

            # 64 -> 128
            ops.ConvTranspose2d(quarter, quarter, kernel_size=2, stride=2),
            nn.GELU(),
            nn.BatchNorm2d(quarter),

            ops.Conv2d(quarter, quarter, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(quarter),

            ops.Conv2d(quarter, c_out, kernel_size=1),
        ]
        self.blocks = nn.Sequential(*stages)

    def forward(self, x):
        """Decode latents and rescale the output from [0, 1]-centered to [-1, 1]."""
        decoded = self.blocks(x)
        return (decoded - 0.5) * 2.0
|
| 87 |
+
|
| 88 |
+
class StageC_coder(nn.Module):
    """Container pairing the Stage C encoder and previewer.

    ``encode`` maps pixel images to Stage C latents via the EfficientNet
    encoder; ``decode`` maps latents back to preview images.
    """
    def __init__(self):
        super().__init__()
        self.previewer = Previewer()
        self.encoder = EfficientNetEncoder()

    def encode(self, x):
        # Images -> latents.
        return self.encoder(x)

    def decode(self, x):
        # Latents -> preview images.
        return self.previewer(x)
|
ComfyUI/comfy/ldm/chroma/layers.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor, nn
|
| 3 |
+
|
| 4 |
+
from comfy.ldm.flux.math import attention
|
| 5 |
+
from comfy.ldm.flux.layers import (
|
| 6 |
+
MLPEmbedder,
|
| 7 |
+
RMSNorm,
|
| 8 |
+
QKNorm,
|
| 9 |
+
SelfAttention,
|
| 10 |
+
ModulationOut,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ChromaModulationOut(ModulationOut):
    """ModulationOut whose (shift, scale, gate) triple is sliced from a packed tensor."""

    @classmethod
    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
        """Build a modulation triple from three consecutive rows of *tensor*.

        Rows ``offset``, ``offset + 1`` and ``offset + 2`` along dim 1 become
        shift, scale and gate respectively; each is kept as a length-1 slice
        so it broadcasts over the sequence dimension.
        """
        shift, scale, gate = (
            tensor[:, offset + i : offset + i + 1, :] for i in range(3)
        )
        return cls(shift=shift, scale=scale, gate=gate)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Approximator(nn.Module):
    """Residual MLP stack mapping conditioning embeddings to modulation vectors.

    Applies an input projection, ``n_layers`` pre-normalized residual MLP
    layers, then an output projection.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        # One (norm, MLP) pair per residual layer.
        self.layers = nn.ModuleList(
            [MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations)
             for _ in range(n_layers)]
        )
        self.norms = nn.ModuleList(
            [RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations)
             for _ in range(n_layers)]
        )
        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

    @property
    def device(self):
        # Device of the module (assumes all parameters live on one device).
        return next(self.parameters()).device

    def forward(self, x: Tensor) -> Tensor:
        """Project in, apply the residual pre-norm MLP layers, project out."""
        hidden = self.in_proj(x)
        for mlp, norm in zip(self.layers, self.norms):
            hidden = hidden + mlp(norm(hidden))
        return self.out_proj(hidden)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DoubleStreamBlock(nn.Module):
    """MMDiT-style block running image and text streams through separate
    norms/MLPs but one joint attention pass.

    ``vec`` supplies the modulation pairs ``((img_mod1, img_mod2),
    (txt_mod1, txt_mod2))``: mod1 modulates/gates the attention path,
    mod2 the MLP path, for each stream.
    """
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # Image stream: pre-attention norm + fused self-attention projections.
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        # Image stream: pre-MLP norm + 2-layer tanh-GELU MLP.
        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        # Text stream mirrors the image stream layer-for-layer.
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        # NOTE(review): stored but never read in this forward — presumably
        # consumed by callers/loaders; confirm before removing.
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
        """One joint-attention + per-stream-MLP step.

        NOTE: updates ``img`` and ``txt`` in place via ``addcmul_`` and also
        returns them.
        """
        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        # prepare image for attention: shift/scale-modulate the normed input
        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
        img_qkv = self.img_attn.qkv(img_modulated)
        # split fused QKV: (B, L, 3*H*D) -> three (B, H, L, D) tensors
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the joint sequence: txt tokens first, then img
        attn = attention(torch.cat((txt_q, img_q), dim=2),
                         torch.cat((txt_k, img_k), dim=2),
                         torch.cat((txt_v, img_v), dim=2),
                         pe=pe, mask=attn_mask)

        # split the joint result back into the two streams
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # image stream: gated residual adds for the attention and MLP paths
        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))

        # text stream: same gated residual structure
        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))

        if txt.dtype == torch.float16:
            # clamp fp16 overflow/NaN artifacts to the finite fp16 range
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    One fused linear produces QKV and the MLP input in a single matmul; a
    second fused linear consumes the concatenated attention output and
    activated MLP stream.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # Fallback to the conventional 1/sqrt(head_dim) attention scale.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in (fused into one projection)
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out (fused into one projection)
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")

    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
        """Apply one modulated attention+MLP step.

        ``vec`` is a single modulation triple (shift, scale, gate).
        NOTE: updates ``x`` in place via ``addcmul_`` and also returns it.
        """
        mod = vec
        # Shift/scale-modulate the normed input.
        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
        # Split the fused projection into the attention (QKV) and MLP streams.
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        # (B, L, 3*H*D) -> three (B, H, L, D) tensors.
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # Gated residual add, in place.
        x.addcmul_(mod.gate, output)
        if x.dtype == torch.float16:
            # Clamp fp16 overflow/NaN artifacts to the finite fp16 range.
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class LastLayer(nn.Module):
    """Final head: adaptive-LayerNorm modulation followed by a linear projection."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        """Modulate the normalized input with the (shift, scale) pair in *vec*, then project."""
        shift, scale = vec
        # Collapse the length-1 modulation axis, then re-insert it for broadcasting.
        shift = shift.squeeze(1)[:, None, :]
        scale = scale.squeeze(1)[:, None, :]
        normed = self.norm_final(x)
        modulated = torch.addcmul(shift, 1 + scale, normed)
        return self.linear(modulated)
|
ComfyUI/comfy/ldm/chroma/model.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Original code can be found on: https://github.com/black-forest-labs/flux
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
import comfy.ldm.common_dit
|
| 9 |
+
|
| 10 |
+
from comfy.ldm.flux.layers import (
|
| 11 |
+
EmbedND,
|
| 12 |
+
timestep_embedding,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from .layers import (
|
| 16 |
+
DoubleStreamBlock,
|
| 17 |
+
LastLayer,
|
| 18 |
+
SingleStreamBlock,
|
| 19 |
+
Approximator,
|
| 20 |
+
ChromaModulationOut,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class ChromaParams:
    """Configuration for the Chroma flow-matching transformer."""
    in_channels: int           # width of each incoming (patchified) image token
    out_channels: int          # width produced by the final layer
    context_in_dim: int        # dimension of incoming text-context embeddings
    hidden_size: int           # transformer width; must be divisible by num_heads
    mlp_ratio: float           # MLP hidden dim = hidden_size * mlp_ratio
    num_heads: int             # attention head count
    depth: int                 # number of double-stream (img+txt) blocks
    depth_single_blocks: int   # number of single-stream blocks
    axes_dim: list             # per-axis rotary dims; must sum to hidden_size // num_heads
    theta: int                 # rotary embedding base frequency
    patch_size: int            # spatial patch size used when (un)patchifying
    qkv_bias: bool             # whether QKV projections carry a bias
    in_dim: int                # input dim of the distilled-guidance Approximator
    out_dim: int               # output dim of the Approximator
    hidden_dim: int            # hidden dim of the Approximator
    n_layers: int              # residual layers in the Approximator
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Chroma(nn.Module):
    """
    Transformer model for flow matching on sequences.

    Flux-derived architecture where the per-block modulation vectors are not
    produced by a timestep MLP per block but by one shared ``Approximator``
    ("distilled guidance layer") whose output is sliced per block via
    ``get_modulations``.
    """

    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        """Build the model from ChromaParams fields passed as **kwargs.

        ``operations`` supplies the layer factories (Linear/LayerNorm etc.),
        letting callers swap in weight-offload or quantized variants.
        """
        super().__init__()
        self.dtype = dtype
        params = ChromaParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.in_dim = params.in_dim
        self.out_dim = params.out_dim
        self.hidden_dim = params.hidden_dim
        self.n_layers = params.n_layers
        # Rotary position embedder shared by all blocks.
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
        # Shared modulation generator replacing Flux's per-block timestep MLPs.
        self.distilled_guidance_layer = Approximator(
            in_dim=self.in_dim,
            hidden_dim=self.hidden_dim,
            out_dim=self.out_dim,
            n_layers=self.n_layers,
            dtype=dtype, device=device, operations=operations
        )

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)

        # Indices of blocks to skip (empty by default); `lite` toggles a reduced variant.
        self.skip_mmdit = []
        self.skip_dit = []
        self.lite = False

    def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
        """Slice the packed modulation tensor for one block.

        Returns a ChromaModulationOut (single), a pair of them (double_img /
        double_txt), or a (shift, scale) tuple (final).
        Raises ValueError for an unknown ``block_type``.
        """
        # This function slices up the modulations tensor which has the following layout:
        #   single     : num_single_blocks * 3 elements
        #   double_img : num_double_blocks * 6 elements
        #   double_txt : num_double_blocks * 6 elements
        #   final      : 2 elements
        if block_type == "final":
            return (tensor[:, -2:-1, :], tensor[:, -1:, :])
        single_block_count = self.params.depth_single_blocks
        double_block_count = self.params.depth
        offset = 3 * idx
        if block_type == "single":
            return ChromaModulationOut.from_offset(tensor, offset)
        # Double block modulations are 6 elements so we double 3 * idx.
        offset *= 2
        if block_type in {"double_img", "double_txt"}:
            # Advance past the single block modulations.
            offset += 3 * single_block_count
            if block_type == "double_txt":
                # Advance past the double block img modulations.
                offset += 6 * double_block_count
            return (
                ChromaModulationOut.from_offset(tensor, offset),
                ChromaModulationOut.from_offset(tensor, offset + 3),
            )
        raise ValueError("Bad block_type")

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        guidance: Tensor = None,
        control = None,
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:
        """Core forward over already-patchified sequences.

        ``img``/``txt`` are (B, L, C) token sequences; ``control`` optionally
        carries ControlNet residuals; ``transformer_options`` may replace
        individual blocks via "patches_replace".
        """
        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)

        # distilled vector guidance
        # 344 modulation rows total; matches the get_modulations layout above
        # for the shipped depth/depth_single_blocks config — TODO confirm if configs vary.
        mod_index_length = 344
        distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
        # guidance = guidance *
        distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)

        # get all modulation index
        modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
        # we need to broadcast the modulation index here so each batch has all of the index
        modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
        # and we need to broadcast timestep and guidance along too
        timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.dtype).to(img.device, img.dtype)
        # then and only then we could concatenate it together
        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)

        # One Approximator pass produces the modulations for every block.
        mod_vectors = self.distilled_guidance_layer(input_vec)

        txt = self.txt_in(txt)

        # Joint rotary position embedding over txt-then-img token ids.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if i not in self.skip_mmdit:
                double_mod = (
                    self.get_modulations(mod_vectors, "double_img", idx=i),
                    self.get_modulations(mod_vectors, "double_txt", idx=i),
                )
                if ("double_block", i) in blocks_replace:
                    # Patched path: hand the original block to the replacement wrapper.
                    def block_wrap(args):
                        out = {}
                        out["img"], out["txt"] = block(img=args["img"],
                                                       txt=args["txt"],
                                                       vec=args["vec"],
                                                       pe=args["pe"],
                                                       attn_mask=args.get("attn_mask"))
                        return out

                    out = blocks_replace[("double_block", i)]({"img": img,
                                                               "txt": txt,
                                                               "vec": double_mod,
                                                               "pe": pe,
                                                               "attn_mask": attn_mask},
                                                              {"original_block": block_wrap})
                    txt = out["txt"]
                    img = out["img"]
                else:
                    img, txt = block(img=img,
                                     txt=txt,
                                     vec=double_mod,
                                     pe=pe,
                                     attn_mask=attn_mask)

                if control is not None: # Controlnet
                    control_i = control.get("input")
                    if i < len(control_i):
                        add = control_i[i]
                        if add is not None:
                            img += add

        # Single-stream blocks operate on the concatenated txt+img sequence.
        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            if i not in self.skip_dit:
                single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                if ("single_block", i) in blocks_replace:
                    def block_wrap(args):
                        out = {}
                        out["img"] = block(args["img"],
                                           vec=args["vec"],
                                           pe=args["pe"],
                                           attn_mask=args.get("attn_mask"))
                        return out

                    out = blocks_replace[("single_block", i)]({"img": img,
                                                               "vec": single_mod,
                                                               "pe": pe,
                                                               "attn_mask": attn_mask},
                                                              {"original_block": block_wrap})
                    img = out["img"]
                else:
                    img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)

                if control is not None: # Controlnet
                    control_o = control.get("output")
                    if i < len(control_o):
                        add = control_o[i]
                        if add is not None:
                            # Residual applies only to the image portion of the joint sequence.
                            img[:, txt.shape[1] :, ...] += add

        # Drop the txt tokens before the final projection.
        img = img[:, txt.shape[1] :, ...]
        final_mod = self.get_modulations(mod_vectors, "final")
        img = self.final_layer(img, vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
        return img

    def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
        """Patchify the (B, C, H, W) input, run forward_orig, unpatchify back.

        Spatial rotary ids are written into channels 1 (row) and 2 (column)
        of the 3-channel id tensor; channel 0 stays zero.
        """
        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)

        # Rounded-up number of patches per spatial axis.
        h_len = ((h + (self.patch_size // 2)) // self.patch_size)
        w_len = ((w + (self.patch_size // 2)) // self.patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        # Text tokens carry all-zero position ids.
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
        # Unpatchify and crop off any padding added by pad_to_patch_size.
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]
|
ComfyUI/comfy/ldm/cosmos/blocks.py
ADDED
|
@@ -0,0 +1,797 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from typing import Optional
|
| 18 |
+
import logging
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from einops import rearrange, repeat
|
| 23 |
+
from einops.layers.torch import Rearrange
|
| 24 |
+
from torch import nn
|
| 25 |
+
|
| 26 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
    """Build a normalization layer from a one-letter code.

    Args:
        name: "I" for identity (no normalization), "R" for RMSNorm.
        channels: Number of channels the norm operates over.
        weight_args: Extra kwargs (e.g. dtype/device) forwarded to the norm.
        operations: Module namespace providing the RMSNorm implementation.

    Raises:
        ValueError: If `name` is not a recognized code.
    """
    if name == "I":
        return nn.Identity()
    if name == "R":
        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
    raise ValueError(f"Normalization {name} not found")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BaseAttentionOp(nn.Module):
    """Marker base class for pluggable attention operations."""

    def __init__(self):
        super().__init__()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Attention(nn.Module):
    """
    Generalized attention impl.

    Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided.
    If `context_dim` is None, self-attention is assumed.

    Parameters:
        query_dim (int): Dimension of each query vector.
        context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed.
        heads (int, optional): Number of attention heads. Defaults to 8.
        dim_head (int, optional): Dimension of each head. Defaults to 64.
        dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0.
        attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default.
        qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
        out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False.
        qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections.
            Defaults to "SSI".
        qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections.
            Defaults to 'per_head'. Only support 'per_head'.

    Examples:
        >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1)
        >>> query = torch.randn(10, 128)  # Batch size of 10
        >>> context = torch.randn(10, 256)  # Batch size of 10
        >>> output = attn(query, context)  # Perform the attention operation

    Note:
        The rearranges in `cal_qkv`/`forward` treat inputs as sequence-first
        ``(s, b, dim)``; see those methods.
        https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    """

    def __init__(
        self,
        query_dim: int,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        attn_op: Optional[BaseAttentionOp] = None,
        qkv_bias: bool = False,
        out_bias: bool = False,
        qkv_norm: str = "SSI",
        qkv_norm_mode: str = "per_head",
        backend: str = "transformer_engine",
        qkv_format: str = "bshd",
        weight_args={},
        operations=None,
    ) -> None:
        super().__init__()

        self.is_selfattn = context_dim is None  # self attention

        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head
        self.qkv_norm_mode = qkv_norm_mode
        self.qkv_format = qkv_format

        if self.qkv_norm_mode == "per_head":
            # Per-head normalization: each norm layer sees `dim_head` channels
            # and is applied after the projection output is split into heads.
            norm_dim = dim_head
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        self.backend = backend

        # Each projection is a Sequential of [Linear, norm]. The two stages are
        # invoked separately in cal_qkv (norm runs per head after the reshape);
        # the Sequential container is kept for checkpoint-key compatibility.
        self.to_q = nn.Sequential(
            operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
        )
        self.to_k = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
        )
        self.to_v = nn.Sequential(
            operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
        )

        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args),
            nn.Dropout(dropout),
        )

    def cal_qkv(
        self, x, context=None, mask=None, rope_emb=None, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        del kwargs

        """
        self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers.
        Before 07/24/2024, these modules normalize across all heads.
        After 07/24/2024, to support tensor parallelism and follow the common practice in the community,
        we support to normalize per head.
        To keep the checkpoint copatibility with the previous code,
        we keep the nn.Sequential but call the projection and the normalization layers separately.
        We use a flag `self.qkv_norm_mode` to control the normalization behavior.
        The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head.
        """
        if self.qkv_norm_mode == "per_head":
            q = self.to_q[0](x)
            context = x if context is None else context
            k = self.to_k[0](context)
            v = self.to_v[0](context)
            # Split heads: sequence-first (s, b, heads*dim_head)
            # -> (b, heads, s, dim_head) for the attention kernel.
            q, k, v = map(
                lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
                (q, k, v),
            )
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        # Per-head normalization (identity unless configured otherwise via qkv_norm).
        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            # apply_rotary_pos_emb inlined: the reshape pairs channel i with
            # channel i + dim_head/2 and applies a 2x2 rotation per pair using
            # the precomputed rope table (rope_emb[..., 0] = cos-like row,
            # rope_emb[..., 1] = sin-like row).
            q_shape = q.shape
            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)

            # apply_rotary_pos_emb inlined (same transform for keys)
            k_shape = k.shape
            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
        return q, k, v

    def forward(
        self,
        x,
        context=None,
        mask=None,
        rope_emb=None,
        **kwargs,
    ):
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None

        NOTE(review): the rearrange in cal_qkv ("s b (n c)") treats x as
        sequence-first (S, B, D), which disagrees with the [B, Mq, K] shapes
        stated above — confirm against callers (VideoAttn passes (THW, B, D)).
        """
        q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
        out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
        del q, k, v
        # Merge heads and restore sequence-first layout: (b, n, s, c) -> (s, b, n*c).
        out = rearrange(out, " b n s c -> s b (n c)")
        return self.to_out(out)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class FeedForward(nn.Module):
    """
    Transformer feed-forward network with optional GLU-style gating.

    Parameters:
        d_model (int): Dimensionality of input features.
        d_ff (int): Dimensionality of the hidden layer.
        dropout (float, optional): Dropout rate (the forward pass asserts it is
            0.0, i.e. dropout is configured but deliberately skipped). Defaults to 0.1.
        activation (callable, optional): Activation applied after the first
            linear layer. Defaults to nn.ReLU().
        is_gated (bool, optional): If True, multiplies the activated hidden
            state by a separate gate projection of the input. Defaults to False.
        bias (bool, optional): Whether the two main linear layers use a bias.
            Defaults to False.

    Example:
        >>> ff = FeedForward(d_model=512, d_ff=2048)
        >>> x = torch.randn(64, 10, 512)
        >>> print(ff(x).shape)  # (64, 10, 512)
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation=nn.ReLU(),
        is_gated: bool = False,
        bias: bool = False,
        weight_args={},
        operations=None,
    ) -> None:
        super().__init__()
        # Attribute names (layer1/layer2/linear_gate) are fixed for checkpoint compatibility.
        self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args)
        self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.is_gated = is_gated
        if is_gated:
            self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args)

    def forward(self, x: torch.Tensor):
        activated = self.activation(self.layer1(x))
        # Gated variant multiplies by a linear projection of the *input*.
        hidden = activated * self.linear_gate(x) if self.is_gated else activated
        assert self.dropout.p == 0.0, "we skip dropout"
        return self.layer2(hidden)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class GPT2FeedForward(FeedForward):
    """GELU-activated, ungated feed-forward block (GPT-2 style)."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None):
        super().__init__(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=nn.GELU(),
            is_gated=False,
            bias=bias,
            weight_args=weight_args,
            operations=operations,
        )

    def forward(self, x: torch.Tensor):
        # Dropout is configured but intentionally never applied.
        assert self.dropout.p == 0.0, "we skip dropout"
        return self.layer2(self.activation(self.layer1(x)))
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def modulate(x, shift, scale):
    """AdaLN-style affine modulation: x * (1 + scale) + shift.

    `shift` and `scale` are per-batch (B, D) tensors; unsqueeze(1) broadcasts
    them across x's middle (sequence/token) dimension.
    """
    scaled = x * (scale.unsqueeze(1) + 1)
    return scaled + shift.unsqueeze(1)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class Timesteps(nn.Module):
    """Sinusoidal timestep embedding: [B] -> [B, num_channels], laid out as [cos | sin]."""

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        half_dim = self.num_channels // 2
        # Log-spaced frequencies from 1 down to 1/10000; the `- 0.0` divisor
        # shift mirrors the reference implementation (zero freq shift).
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - 0.0)
        freqs = torch.exp(exponent)
        args = timesteps[:, None].float() * freqs[None, :]
        # Cosine half first, then sine half.
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class TimestepEmbedding(nn.Module):
    """Two-layer MLP over timestep features.

    Without AdaLN-LoRA the MLP output is the embedding. With AdaLN-LoRA the
    MLP instead produces 3*out_features modulation deltas, and the raw input
    features serve as the base embedding.

    Returns (embedding, adaln_lora_chunks_or_None) from forward.
    """

    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
        super().__init__()
        logging.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        # LoRA path emits 3 chunks (shift/scale/gate deltas) and drops biases.
        out_dim = 3 * out_features if use_adaln_lora else out_features
        self.linear_2 = operations.Linear(out_features, out_dim, bias=not use_adaln_lora, **weight_args)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        hidden = self.linear_2(self.activation(self.linear_1(sample)))
        if self.use_adaln_lora:
            # Base embedding is the untouched input; MLP output is the LoRA deltas.
            return sample, hidden
        return hidden, None
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class FourierFeatures(nn.Module):
    """
    Random Fourier feature layer: [B] -> [B, num_channels].

    Frequencies and phases are sampled once at construction and stored as
    persistent buffers, so they are saved/loaded with the checkpoint.

    Parameters:
        num_channels (int): Number of Fourier features to generate.
        bandwidth (float, optional): Scale factor on the sampled frequencies. Defaults to 1.
        normalize (bool, optional): If True, scale outputs by sqrt(2) to
            normalize feature variance. Defaults to False.

    Example:
        >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
        >>> layer(torch.randn(10)).shape
        torch.Size([10, 256])
    """

    def __init__(self, num_channels, bandwidth=1, normalize=False):
        super().__init__()
        two_pi = 2 * np.pi
        # Sampling order (randn then rand) is fixed so seeded init is reproducible.
        self.register_buffer("freqs", two_pi * bandwidth * torch.randn(num_channels), persistent=True)
        self.register_buffer("phases", two_pi * torch.rand(num_channels), persistent=True)
        self.gain = np.sqrt(2) if normalize else 1

    def forward(self, x, gain: float = 1.0):
        """Return cos(outer(x, freqs) + phases) scaled by self.gain * gain, in x's dtype.

        Args:
            x (torch.Tensor): 1-D input tensor.
            gain (float, optional): Extra multiplicative gain. Defaults to 1.0.
        """
        orig_dtype = x.dtype
        # Outer product [B] x [D] -> [B, D], computed in float32 for stability.
        angles = torch.ger(x.to(torch.float32), self.freqs.to(torch.float32)) + self.phases.to(torch.float32)
        return torch.cos(angles).mul(self.gain * gain).to(orig_dtype)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class PatchEmbed(nn.Module):
    """
    Embeds spatio-temporal patches of a video tensor.

    The input (B, C, T, H, W) is chopped into non-overlapping patches of size
    (temporal_patch_size, spatial_patch_size, spatial_patch_size); each patch
    is flattened and linearly projected to `out_channels`.

    Parameters:
        spatial_patch_size (int): Edge length of each spatial patch.
        temporal_patch_size (int): Length of each temporal patch.
        in_channels (int): Number of input channels. Default: 3.
        out_channels (int): Embedding dimension per patch. Default: 768.
        bias (bool): Whether the projection uses a bias. Default: True.
    """

    def __init__(
        self,
        spatial_patch_size,
        temporal_patch_size,
        in_channels=3,
        out_channels=768,
        bias=True,
        weight_args={},
        operations=None,
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        patch_dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
        self.proj = nn.Sequential(
            # Fold each (r, m, n) patch into the channel axis:
            # (B, C, T, H, W) -> (B, t, h, w, C*r*m*n).
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            operations.Linear(patch_dim, out_channels, bias=bias, **weight_args),
        )
        self.out = nn.Identity()

    def forward(self, x):
        """
        Embed a video tensor.

        Parameters:
            x (torch.Tensor): Shape (B, C, T, H, W); T must be divisible by the
                temporal patch size and H/W by the spatial patch size.

        Returns:
            torch.Tensor: Patch embeddings of shape (B, t, h, w, out_channels).
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        assert T % self.temporal_patch_size == 0
        return self.out(self.proj(x))
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class FinalLayer(nn.Module):
    """
    Final projection layer of the video DiT: AdaLN-modulated LayerNorm
    followed by a bias-free linear projection to patch-sized output channels.
    """

    def __init__(
        self,
        hidden_size,
        spatial_patch_size,
        temporal_patch_size,
        out_channels,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args={},
        operations=None,
    ):
        super().__init__()
        # Affine-free LayerNorm: scale/shift are supplied by adaLN_modulation.
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args
        )
        self.hidden_size = hidden_size
        self.n_adaln_chunks = 2  # shift and scale only (no gate at the output)
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            # Low-rank bottleneck variant of the modulation MLP.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args),
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args)
            )

    def forward(
        self,
        x_BT_HW_D,
        emb_B_D,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ):
        modulation = self.adaLN_modulation(emb_B_D)
        if self.use_adaln_lora:
            assert adaln_lora_B_3D is not None
            # Only the first two of the three LoRA chunks (shift, scale) apply here.
            modulation = modulation + adaln_lora_B_3D[:, : 2 * self.hidden_size]
        shift_B_D, scale_B_D = modulation.chunk(2, dim=1)

        # Broadcast the per-batch modulation across the T frames folded into
        # the leading (B*T) axis of x.
        batch = emb_B_D.shape[0]
        frames = x_BT_HW_D.shape[0] // batch
        shift_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=frames)
        scale_BT_D = repeat(scale_B_D, "b d -> (b t) d", t=frames)
        modulated = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D)
        return self.linear(modulated)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class VideoAttn(nn.Module):
    """
    Attention over video tokens with optional cross-attention.

    Flattens the (T, H, W) axes into a single token axis, runs attention
    (self-attention when `context_dim` is None, cross-attention otherwise),
    then restores the spatio-temporal layout.

    Parameters:
        x_dim (int): Feature dimension of the video tokens.
        context_dim (Optional[int]): Context feature dimension; None selects self-attention.
        num_heads (int): Number of attention heads.
        bias (bool): Bias for the QKV and output projections. Default: False.
        qkv_norm_mode (str): QKV normalization mode; only "per_head" is supported downstream.
        x_format (str): Input format tag. Default: "BTHWD".

    Input shape:
        x: (T, H, W, B, D); context (optional): (M, B, D).
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        bias: bool = False,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
        weight_args={},
        operations=None,
    ) -> None:
        super().__init__()
        self.x_format = x_format
        # "RRI": RMSNorm on q and k, identity on v.
        self.attn = Attention(
            x_dim,
            context_dim,
            num_heads,
            x_dim // num_heads,
            qkv_bias=bias,
            qkv_norm="RRI",
            out_bias=bias,
            qkv_norm_mode=qkv_norm_mode,
            qkv_format="sbhd",
            weight_args=weight_args,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply attention over flattened spatio-temporal tokens.

        Args:
            x: Video features of shape (T, H, W, B, D).
            context: Optional context of shape (M, B, D); None means self-attention.
            crossattn_mask: Optional mask for the cross-attention path.
            rope_emb_L_1_1_D: Optional rotary embedding of shape (L, 1, 1, D), L == T*H*W.

        Returns:
            Tensor of the same shape as `x`.
        """
        T, H, W, B, D = x.shape
        tokens_THW_B_D = rearrange(x, "t h w b d -> (t h w) b d")
        tokens_THW_B_D = self.attn(
            tokens_THW_B_D,
            context,
            crossattn_mask,
            rope_emb=rope_emb_L_1_1_D,
        )
        return rearrange(tokens_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def adaln_norm_state(norm_state, x, scale, shift):
    """Normalize x then apply an AdaLN affine: norm_state(x) * (1 + scale) + shift."""
    return norm_state(x) * (scale + 1) + shift
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
class DITBuildingBlock(nn.Module):
    """
    A building block for the DiT (Diffusion Transformer) architecture that supports different types of
    attention and MLP operations with adaptive layer normalization.

    Parameters:
        block_type (str): Type of block - one of:
            - "cross_attn"/"ca": Cross-attention
            - "full_attn"/"fa": Full self-attention
            - "mlp"/"ff": MLP/feedforward block
        x_dim (int): Dimension of input features
        context_dim (Optional[int]): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads
        mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0
        bias (bool): Whether to use bias in layers. Default: False
        mlp_dropout (float): Dropout rate for MLP. Default: 0.0
        qkv_norm_mode (str): QKV normalization mode. Default: "per_head"
        x_format (str): Input tensor format. Default: "BTHWD"
        use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False
        adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256
    """

    def __init__(
        self,
        block_type: str,
        x_dim: int,
        context_dim: Optional[int],
        num_heads: int,
        mlp_ratio: float = 4.0,
        bias: bool = False,
        mlp_dropout: float = 0.0,
        qkv_norm_mode: str = "per_head",
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args={},
        operations=None
    ) -> None:
        # Accept e.g. "CA"/"FA"/"MLP" as well as lowercase codes.
        block_type = block_type.lower()

        super().__init__()
        self.x_format = x_format
        if block_type in ["cross_attn", "ca"]:
            self.block = VideoAttn(
                x_dim,
                context_dim,
                num_heads,
                bias=bias,
                qkv_norm_mode=qkv_norm_mode,
                x_format=self.x_format,
                weight_args=weight_args,
                operations=operations,
            )
        elif block_type in ["full_attn", "fa"]:
            # Self-attention: context_dim=None makes VideoAttn use x as context.
            self.block = VideoAttn(
                x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations
            )
        elif block_type in ["mlp", "ff"]:
            self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations)
        else:
            raise ValueError(f"Unknown block type: {block_type}")

        self.block_type = block_type
        self.use_adaln_lora = use_adaln_lora

        # Affine-free LayerNorm: scale/shift come from adaLN_modulation below.
        # NOTE(review): this uses plain nn.LayerNorm rather than operations.LayerNorm
        # as elsewhere in the file; with elementwise_affine=False it holds no
        # weights, so there is no checkpoint impact — confirm intent.
        self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
        self.n_adaln_chunks = 3
        if use_adaln_lora:
            # Low-rank (LoRA-style) bottleneck for the modulation MLP.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args),
            )
        else:
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args))

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for dynamically configured blocks with adaptive normalization.

        Args:
            x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D).
            emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation.
            crossattn_emb (Tensor): Tensor for cross-attention blocks.
            crossattn_mask (Optional[Tensor]): Optional mask for cross-attention.
            rope_emb_L_1_1_D (Optional[Tensor]):
                Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training.

        Returns:
            Tensor: The output tensor after processing through the configured block and adaptive normalization.
        """
        # Produce shift/scale/gate from the conditioning embedding; the LoRA
        # path adds precomputed per-block deltas before chunking.
        if self.use_adaln_lora:
            shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk(
                self.n_adaln_chunks, dim=1
            )
        else:
            shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1)

        # (B, D) -> (1, 1, 1, B, D): broadcast over the three leading
        # spatio-temporal axes of the (T, H, W, B, D) token tensor.
        shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = (
            shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
            scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
            gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0),
        )

        # Pre-norm residual update: x + gate * block(adaLN(x)).
        if self.block_type in ["mlp", "ff"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
            )
        elif self.block_type in ["full_attn", "fa"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
                context=None,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
            )
        elif self.block_type in ["cross_attn", "ca"]:
            x = x + gate_1_1_1_B_D * self.block(
                adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
                context=crossattn_emb,
                crossattn_mask=crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
            )
        else:
            raise ValueError(f"Unknown block type: {self.block_type}")

        return x
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
class GeneralDITTransformerBlock(nn.Module):
    """
    One full DiT transformer layer assembled from DITBuildingBlocks.

    The `block_config` string lists the sub-blocks to run, separated by "-":
    "ca"/"cross_attn" (cross-attention), "fa"/"full_attn" (self-attention),
    "mlp"/"ff" (feed-forward). E.g. "ca-fa-mlp" runs cross-attention, then
    self-attention, then an MLP, each with its own AdaLN modulation.

    Parameters:
        x_dim (int): Feature dimension of the video tokens.
        context_dim (int): Feature dimension of the cross-attention context.
        num_heads (int): Number of attention heads.
        block_config (str): Dash-separated sequence of block types (see above).
        mlp_ratio (float): Hidden-size multiplier for MLP blocks. Default: 4.0.
        x_format (str): Input tensor format tag. Default: "BTHWD".
        use_adaln_lora (bool): Use the low-rank AdaLN variant. Default: False.
        adaln_lora_dim (int): Rank of the AdaLN-LoRA bottleneck. Default: 256.
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        block_config: str,
        mlp_ratio: float = 4.0,
        x_format: str = "BTHWD",
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        weight_args={},
        operations=None
    ):
        super().__init__()
        self.x_format = x_format
        self.blocks = nn.ModuleList(
            DITBuildingBlock(
                block_type,
                x_dim,
                context_dim,
                num_heads,
                mlp_ratio,
                x_format=self.x_format,
                use_adaln_lora=use_adaln_lora,
                adaln_lora_dim=adaln_lora_dim,
                weight_args=weight_args,
                operations=operations,
            )
            for block_type in block_config.split("-")
        )

    def forward(
        self,
        x: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Thread the features through each sub-block in configured order.
        for sub_block in self.blocks:
            x = sub_block(
                x,
                emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
            )
        return x
|
ComfyUI/comfy/ldm/cosmos/model.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from typing import Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from einops import rearrange
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torchvision import transforms
|
| 26 |
+
|
| 27 |
+
from enum import Enum
|
| 28 |
+
import logging
|
| 29 |
+
|
| 30 |
+
from .blocks import (
|
| 31 |
+
FinalLayer,
|
| 32 |
+
GeneralDITTransformerBlock,
|
| 33 |
+
PatchEmbed,
|
| 34 |
+
TimestepEmbedding,
|
| 35 |
+
Timesteps,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class DataType(Enum):
    # Discriminates the two input modalities this DiT can process; used e.g. by
    # forward_before_blocks to validate the `data_type` argument.
    IMAGE = "image"
    VIDEO = "video"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GeneralDIT(nn.Module):
    """
    A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.

    Args:
        max_img_h (int): Maximum height of the input images.
        max_img_w (int): Maximum width of the input images.
        max_frames (int): Maximum number of frames in the video sequence.
        in_channels (int): Number of input channels (e.g., RGB channels for color images).
        out_channels (int): Number of output channels.
        patch_spatial (tuple): Spatial resolution of patches for input processing.
        patch_temporal (int): Temporal resolution of patches for input processing.
        concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
        block_config (str): Configuration of the transformer block. See Notes for supported block types.
        model_channels (int): Base number of channels used throughout the model.
        num_blocks (int): Number of transformer blocks.
        num_heads (int): Number of heads in the multi-head attention layers.
        mlp_ratio (float): Expansion ratio for MLP blocks.
        block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD').
        crossattn_emb_channels (int): Number of embedding channels for cross-attention.
        use_cross_attn_mask (bool): Whether to use mask in cross-attention.
        pos_emb_cls (str): Type of positional embeddings.
        pos_emb_learnable (bool): Whether positional embeddings are learnable.
        pos_emb_interpolation (str): Method for interpolating positional embeddings.
        affline_emb_norm (bool): Whether to normalize affine embeddings.
        use_adaln_lora (bool): Whether to use AdaLN-LoRA.
        adaln_lora_dim (int): Dimension for AdaLN-LoRA.
        rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
        rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
        rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
        extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
        extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings.
        extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
        extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
        extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.

    Notes:
        Supported block types in block_config:
        * cross_attn, ca: Cross attention
        * full_attn: Full attention on all flattened tokens
        * mlp, ff: Feed forward block
    """

    def __init__(
        self,
        max_img_h: int,
        max_img_w: int,
        max_frames: int,
        in_channels: int,
        out_channels: int,
        patch_spatial: tuple,
        patch_temporal: int,
        concat_padding_mask: bool = True,
        # attention settings
        block_config: str = "FA-CA-MLP",
        model_channels: int = 768,
        num_blocks: int = 10,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        block_x_format: str = "BTHWD",
        # cross attention settings
        crossattn_emb_channels: int = 1024,
        use_cross_attn_mask: bool = False,
        # positional embedding settings
        pos_emb_cls: str = "sincos",
        pos_emb_learnable: bool = False,
        pos_emb_interpolation: str = "crop",
        affline_emb_norm: bool = False,  # whether or not to normalize the affine embedding
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        rope_h_extrapolation_ratio: float = 1.0,
        rope_w_extrapolation_ratio: float = 1.0,
        rope_t_extrapolation_ratio: float = 1.0,
        extra_per_block_abs_pos_emb: bool = False,
        extra_per_block_abs_pos_emb_type: str = "sincos",
        extra_h_extrapolation_ratio: float = 1.0,
        extra_w_extrapolation_ratio: float = 1.0,
        extra_t_extrapolation_ratio: float = 1.0,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.model_channels = model_channels
        self.use_cross_attn_mask = use_cross_attn_mask
        self.concat_padding_mask = concat_padding_mask
        # positional embedding settings
        self.pos_emb_cls = pos_emb_cls
        self.pos_emb_learnable = pos_emb_learnable
        self.pos_emb_interpolation = pos_emb_interpolation
        self.affline_emb_norm = affline_emb_norm
        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
        self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower()
        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
        self.dtype = dtype
        weight_args = {"device": device, "dtype": dtype}

        # One extra input channel carries the padding mask appended in prepare_embedded_sequence.
        in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.x_embedder = PatchEmbed(
            spatial_patch_size=patch_spatial,
            temporal_patch_size=patch_temporal,
            in_channels=in_channels,
            out_channels=model_channels,
            bias=False,
            weight_args=weight_args,
            operations=operations,
        )

        self.build_pos_embed(device=device, dtype=dtype)
        self.block_x_format = block_x_format
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        # t_embedder[0] maps scalar timesteps to sinusoidal features; t_embedder[1] projects
        # them to (emb_B_D, adaln_lora_B_3D).
        self.t_embedder = nn.ModuleList(
            [Timesteps(model_channels),
             TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
        )

        self.blocks = nn.ModuleDict()

        for idx in range(num_blocks):
            self.blocks[f"block{idx}"] = GeneralDITTransformerBlock(
                x_dim=model_channels,
                context_dim=crossattn_emb_channels,
                num_heads=num_heads,
                block_config=block_config,
                mlp_ratio=mlp_ratio,
                x_format=self.block_x_format,
                use_adaln_lora=use_adaln_lora,
                adaln_lora_dim=adaln_lora_dim,
                weight_args=weight_args,
                operations=operations,
            )

        if self.affline_emb_norm:
            logging.debug("Building affine embedding normalization layer")
            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
        else:
            self.affline_norm = nn.Identity()

        self.final_layer = FinalLayer(
            hidden_size=self.model_channels,
            spatial_patch_size=self.patch_spatial,
            temporal_patch_size=self.patch_temporal,
            out_channels=self.out_channels,
            use_adaln_lora=self.use_adaln_lora,
            adaln_lora_dim=self.adaln_lora_dim,
            weight_args=weight_args,
            operations=operations,
        )

    def build_pos_embed(self, device=None, dtype=None):
        """Construct self.pos_embedder (and optionally self.extra_pos_embedder).

        Raises:
            ValueError: if ``self.pos_emb_cls`` is not a supported class name.
        """
        if self.pos_emb_cls == "rope3d":
            cls_type = VideoRopePosition3DEmb
        else:
            raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")

        logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
        kwargs = dict(
            model_channels=self.model_channels,
            len_h=self.max_img_h // self.patch_spatial,
            len_w=self.max_img_w // self.patch_spatial,
            len_t=self.max_frames // self.patch_temporal,
            is_learnable=self.pos_emb_learnable,
            interpolation=self.pos_emb_interpolation,
            head_dim=self.model_channels // self.num_heads,
            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
            device=device,
        )
        self.pos_embedder = cls_type(
            **kwargs,
        )

        if self.extra_per_block_abs_pos_emb:
            assert self.extra_per_block_abs_pos_emb_type in [
                "learnable",
            ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}"
            # Reuse the RoPE kwargs but swap in the extra-embedding extrapolation ratios.
            kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
            kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
            kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
            kwargs["device"] = device
            kwargs["dtype"] = dtype
            self.extra_pos_embedder = LearnablePosEmbAxis(
                **kwargs,
            )

    def prepare_embedded_sequence(
        self,
        x_B_C_T_H_W: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.

        Args:
            x_B_C_T_H_W (torch.Tensor): video
            fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
                If None, a default value (`self.base_fps`) will be used.
            padding_mask (Optional[torch.Tensor]): current it is not used

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - A tensor of shape (B, T, H, W, D) with the embedded sequence.
                - An optional positional embedding tensor, returned only if the positional embedding class
                (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.

        Notes:
            - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
            - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
            - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
                the `self.pos_embedder` with the shape [T, H, W].
            - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
                `self.pos_embedder` with the fps tensor.
            - Otherwise, the positional embeddings are generated without considering fps.
        """
        if self.concat_padding_mask:
            if padding_mask is not None:
                padding_mask = transforms.functional.resize(
                    padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
                )
            else:
                # No mask supplied: treat every pixel as valid (all-zeros mask channel).
                padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)

            x_B_C_T_H_W = torch.cat(
                [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
            )
        x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

        if self.extra_per_block_abs_pos_emb:
            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
        else:
            extra_pos_emb = None

        if "rope" in self.pos_emb_cls.lower():
            # RoPE embeddings are applied inside attention, so return them instead of adding here.
            return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb

        if "fps_aware" in self.pos_emb_cls:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]
        else:
            x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]

        return x_B_T_H_W_D, None, extra_pos_emb

    def decoder_head(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        origin_shape: Tuple[int, int, int, int, int],  # [B, C, T, H, W]
        crossattn_mask: Optional[torch.Tensor] = None,
        adaln_lora_B_3D: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Project transformer tokens back to the original latent layout (un-patchify)."""
        del crossattn_emb, crossattn_mask  # unused; kept in the signature for interface parity
        B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape
        x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D")
        x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D)
        # This is to ensure x_BT_HW_D has the correct shape because
        # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D).
        x_BT_HW_D = x_BT_HW_D.view(
            B * T_before_patchify // self.patch_temporal,
            H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial,
            -1,
        )
        x_B_D_T_H_W = rearrange(
            x_BT_HW_D,
            "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
            p1=self.patch_spatial,
            p2=self.patch_spatial,
            H=H_before_patchify // self.patch_spatial,
            W=W_before_patchify // self.patch_spatial,
            t=self.patch_temporal,
            B=B,
        )
        return x_B_D_T_H_W

    def forward_before_blocks(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        crossattn_emb: torch.Tensor,
        crossattn_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            crossattn_emb: (B, N, D) tensor of cross-attention embeddings
            crossattn_mask: (B, N) tensor of cross-attention masks

        Returns:
            dict with the patchified/reformatted inputs and all conditioning tensors the
            transformer blocks need.
        """
        del kwargs
        assert isinstance(
            data_type, DataType
        ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later."
        original_shape = x.shape
        x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
            x,
            fps=fps,
            padding_mask=padding_mask,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
        )
        # logging affline scale information
        affline_scale_log_info = {}

        timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype))
        affline_emb_B_D = timesteps_B_D
        affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach()

        if scalar_feature is not None:
            raise NotImplementedError("Scalar feature is not implemented yet.")

        affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach()
        affline_emb_B_D = self.affline_norm(affline_emb_B_D)

        if self.use_cross_attn_mask:
            if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask):
                # Convert a 0/1 mask into additive attention bias (0 keeps, -max masks out).
                crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
            crossattn_mask = crossattn_mask[:, None, None, :]  # .to(dtype=torch.bool)  # [B, 1, 1, length]
        else:
            crossattn_mask = None

        if self.blocks["block0"].x_format == "THWBD":
            x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D")
            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
                extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange(
                    extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D"
                )
            crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D")

            # NOTE(review): tensor truthiness raises for multi-element tensors, and the mask at
            # this point is [B, 1, 1, M], which would not match the "B M" pattern either. This
            # branch appears unexercised (ComfyUI uses use_cross_attn_mask=False) — confirm
            # before enabling cross-attention masks with THWBD blocks.
            if crossattn_mask:
                crossattn_mask = rearrange(crossattn_mask, "B M -> M B")

        elif self.blocks["block0"].x_format == "BTHWD":
            x = x_B_T_H_W_D
        else:
            # Bug fix: self.blocks is an nn.ModuleDict keyed by "block0"; indexing with the
            # integer 0 raised KeyError while formatting this message.
            raise ValueError(f"Unknown x_format {self.blocks['block0'].x_format}")
        output = {
            "x": x,
            "affline_emb_B_D": affline_emb_B_D,
            "crossattn_emb": crossattn_emb,
            "crossattn_mask": crossattn_mask,
            "rope_emb_L_1_1_D": rope_emb_L_1_1_D,
            "adaln_lora_B_3D": adaln_lora_B_3D,
            "original_shape": original_shape,
            "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
        }
        return output

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        # crossattn_emb: torch.Tensor,
        # crossattn_mask: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        image_size: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        scalar_feature: Optional[torch.Tensor] = None,
        data_type: Optional[DataType] = DataType.VIDEO,
        latent_condition: Optional[torch.Tensor] = None,
        latent_condition_sigma: Optional[torch.Tensor] = None,
        condition_video_augment_sigma: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            crossattn_emb: (B, N, D) tensor of cross-attention embeddings
            crossattn_mask: (B, N) tensor of cross-attention masks
            condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to
                augment condition input, the lvg model will condition on the condition_video_augment_sigma value;
                we need forward_before_blocks pass to the forward_before_blocks function.
        """

        crossattn_emb = context
        crossattn_mask = attention_mask

        inputs = self.forward_before_blocks(
            x=x,
            timesteps=timesteps,
            crossattn_emb=crossattn_emb,
            crossattn_mask=crossattn_mask,
            fps=fps,
            image_size=image_size,
            padding_mask=padding_mask,
            scalar_feature=scalar_feature,
            data_type=data_type,
            latent_condition=latent_condition,
            latent_condition_sigma=latent_condition_sigma,
            condition_video_augment_sigma=condition_video_augment_sigma,
            **kwargs,
        )
        x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = (
            inputs["x"],
            inputs["affline_emb_B_D"],
            inputs["crossattn_emb"],
            inputs["crossattn_mask"],
            inputs["rope_emb_L_1_1_D"],
            inputs["adaln_lora_B_3D"],
            inputs["original_shape"],
        )
        # Bug fix: the extra positional embedding is None unless extra_per_block_abs_pos_emb is
        # set; calling .to() unconditionally raised AttributeError in the default configuration.
        extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"]
        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.to(x.dtype)
        del inputs

        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            assert (
                x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
            ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"

        for block in self.blocks.values():
            assert (
                self.blocks["block0"].x_format == block.x_format
            ), f"First block has x_format {self.blocks['block0'].x_format}, got {block.x_format}"

            # The extra absolute positional embedding is re-added before every block.
            if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
                x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
            x = block(
                x,
                affline_emb_B_D,
                crossattn_emb,
                crossattn_mask,
                rope_emb_L_1_1_D=rope_emb_L_1_1_D,
                adaln_lora_B_3D=adaln_lora_B_3D,
            )

        # NOTE(review): this assumes the blocks ran in THWBD format — confirm if BTHWD is ever used.
        x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

        x_B_D_T_H_W = self.decoder_head(
            x_B_T_H_W_D=x_B_T_H_W_D,
            emb_B_D=affline_emb_B_D,
            crossattn_emb=None,
            origin_shape=original_shape,
            crossattn_mask=None,
            adaln_lora_B_3D=adaln_lora_B_3D,
        )

        return x_B_D_T_H_W
|
ComfyUI/comfy/ldm/cosmos/position_embedding.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import List, Optional
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from einops import rearrange, repeat
|
| 20 |
+
from torch import nn
|
| 21 |
+
import math
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
    """
    Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted.

    Args:
        x (torch.Tensor): The input tensor to normalize.
        dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first.
        eps (float, optional): A small constant to ensure numerical stability during division.

    Returns:
        torch.Tensor: The normalized tensor.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    # Compute the norm in float32 for numerical stability regardless of x's dtype.
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    # Rescale so the *average* square norm (not the total) is what gets adjusted.
    correction = math.sqrt(magnitude.numel() / x.numel())
    denom = eps + correction * magnitude
    return x / denom.to(x.dtype)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class VideoPositionEmb(nn.Module):
    """Base class for video positional embeddings.

    Subclasses implement :meth:`generate_embeddings`; :meth:`forward` only
    forwards the input's (B, T, H, W, C) shape to it — input values are never read.
    """

    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
        """
        Produce positional embeddings for a tensor shaped (B, T, H, W, C).

        It delegates the embedding generation to the generate_embeddings function.

        Note: the original signature used ``fps=Optional[torch.Tensor]``, which
        bound the typing object itself as the default (an annotation typo) and
        crashed any subclass that inspected ``fps``; the default is now ``None``.
        """
        B_T_H_W_C = x_B_T_H_W_C.shape
        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)

        return embeddings

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None):
        # Abstract hook: subclasses build the actual embedding from the shape.
        # `dtype` added to match what forward() passes (the old stub rejected it).
        raise NotImplementedError
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class VideoRopePosition3DEmb(VideoPositionEmb):
    """3D rotary position embedding (RoPE) over (time, height, width).

    The head dimension is split into three contiguous frequency groups
    (temporal, height, width). Each group has its own rotation frequencies,
    with NTK-aware theta rescaling to extrapolate beyond training extents.
    """

    def __init__(
        self,
        *,  # enforce keyword arguments
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        enable_fps_modulation: bool = True,
        device=None,
        **kwargs,  # used for compatibility with other positional embeddings; unused in this class
    ):
        del kwargs
        super().__init__()
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w
        self.enable_fps_modulation = enable_fps_modulation

        dim = head_dim
        # Split the head dim: h and w each take (dim // 6) * 2 channels (kept
        # even for the cos/sin pairing), t takes the remainder.
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        # Normalized even indices i/dim — exponents of the RoPE frequency bands.
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        # NTK-aware scaling factors: stretch the RoPE base theta per axis so the
        # embedding extrapolates by the requested ratio.
        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
        device=None,
        dtype=None,
    ):
        """
        Generate embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.

        Returns:
            torch.Tensor: Per-position 2x2 rotation matrices of shape
            ((T*H*W), head_dim // 2, 2, 2), in float32.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        # NTK-scaled RoPE base for each axis.
        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device))
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device))
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))

        B, T, H, W, _ = B_T_H_W_C
        seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
        uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        # Outer product: position index x frequency -> per-position rotation angle.
        half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
        half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)

        # apply sequence scaling in temporal dimension
        if fps is None or self.enable_fps_modulation is False:  # image case
            half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
        else:
            # Rescale temporal positions by base_fps/fps so the rotation rate
            # per second is fps-invariant.
            half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)

        # Pack each angle into a flattened 2x2 rotation matrix (cos, -sin, sin, cos).
        half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
        half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)
        half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1)

        # Broadcast each axis over the full (T, H, W) grid and concatenate along
        # the frequency axis; channel order is [temporal | height | width].
        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W),
                repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W),
                repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H),
            ]
            , dim=-2,
        )

        return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class LearnablePosEmbAxis(VideoPositionEmb):
    """Learned, factorized positional embedding with one table per axis.

    Separate learnable tables for time, height and width are cropped to the
    requested extent, summed into a (B, T, H, W, D) grid, and L2-normalized
    along the channel axis.
    """

    def __init__(
        self,
        *,  # enforce keyword arguments
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        device=None,
        dtype=None,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): only "crop" is currently supported; extrapolation
                strategies (e.g. frequency adjustment) are not implemented yet.
        """
        del kwargs  # unused
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        # One learnable table per axis; values come from the checkpoint
        # (torch.empty leaves them uninitialized here).
        self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
        self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
        self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None, dtype=None) -> torch.Tensor:
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation != "crop":
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        # Crop each learned table to the requested extent.
        crop_h = self.pos_emb_h[:H].to(device=device, dtype=dtype)
        crop_w = self.pos_emb_w[:W].to(device=device, dtype=dtype)
        crop_t = self.pos_emb_t[:T].to(device=device, dtype=dtype)

        # Sum the three axis embeddings into the full (B, T, H, W, D) grid.
        emb = (
            repeat(crop_t, "t d-> b t h w d", b=B, h=H, w=W)
            + repeat(crop_h, "h d-> b t h w d", b=B, t=T, w=W)
            + repeat(crop_w, "w d-> b t h w d", b=B, t=T, h=H)
        )
        assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"

        return normalize(emb, dim=-1, eps=1e-6)
|
ComfyUI/comfy/ldm/cosmos/predict2.py
ADDED
|
@@ -0,0 +1,864 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# original code from: https://github.com/nvidia-cosmos/cosmos-predict2
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from einops.layers.torch import Rearrange
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Callable, Optional, Tuple
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
| 12 |
+
from torchvision import transforms
|
| 13 |
+
|
| 14 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 15 |
+
|
| 16 |
+
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary position embeddings (RoPE) to ``t``.

    The last dimension of ``t`` is split into 2-element pairs, each pair is
    rotated by the 2x2 matrices stored in ``freqs`` (last two axes), and the
    result is packed back into the original shape and dtype.
    """
    # Split the feature axis into (2, D/2), move the pair axis last, and insert
    # a singleton axis so the rotation matrices in `freqs` broadcast against it.
    pairs = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
    # Rotate each pair: out = col0 * p0 + col1 * p1 (a 2x2 matrix-vector product).
    rotated = freqs[..., 0] * pairs[..., 0] + freqs[..., 1] * pairs[..., 1]
    # Undo the axis shuffling and restore the input shape/dtype.
    return rotated.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------- Feed Forward Network -----------------------
|
| 27 |
+
class GPT2FeedForward(nn.Module):
    """Two-layer GELU MLP (GPT-2 style) with bias-free projections.

    Args:
        d_model: Input/output feature width.
        d_ff: Hidden (expansion) width.
    """

    def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None:
        super().__init__()
        self.activation = nn.GELU()
        self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
        self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)

        # Bookkeeping fields; not read by forward().
        self._layer_id = None
        self._dim = d_model
        self._hidden_dim = d_ff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up to d_ff, apply GELU, project back down to d_model."""
        hidden = self.activation(self.layer1(x))
        return self.layer2(hidden)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
    """Multi-head attention via the optimized PyTorch attention backend.

    Rearranges inputs from head-last layout (B, ..., H, D) to the
    (B, H, S, D) layout `optimized_attention` expects (flattening any
    multi-dimensional sequence axes), runs attention, and returns the
    head-merged result.

    Dimension key: B=batch, S=sequence length, H=attention heads, D=head dim.

    Args:
        q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim)
        k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim)
        v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim)

    Returns:
        Attention output tensor with shape (batch, seq_len, n_heads * head_dim)
    """
    q_shape = q_B_S_H_D.shape
    k_shape = k_B_S_H_D.shape
    # Move the head axis to position 1 and flatten the sequence axes:
    # (B, ..., H, D) -> (B, H, S_flat, D).
    q_B_H_S_D = q_B_S_H_D.movedim(-2, 1).reshape(q_shape[0], q_shape[-2], -1, q_shape[-1])
    k_B_H_S_D = k_B_S_H_D.movedim(-2, 1).reshape(k_shape[0], k_shape[-2], -1, k_shape[-1])
    v_B_H_S_D = v_B_S_H_D.movedim(-2, 1).reshape(k_shape[0], k_shape[-2], -1, k_shape[-1])
    # skip_reshape=True: inputs are already (B, H, S, D); heads are merged on output.
    return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, q_shape[-2], skip_reshape=True)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Attention(nn.Module):
    """
    A flexible attention module supporting both self-attention and cross-attention mechanisms.

    This module implements a multi-head attention layer that can operate in either self-attention
    or cross-attention mode. The mode is determined by whether a context dimension is provided.
    Q and K are RMS-normalized per head before attention; rotary position embeddings, when
    supplied, are applied only in self-attention mode. Optional dropout is applied to the output.

    Args:
        query_dim (int): The dimensionality of the query vectors.
        context_dim (int, optional): The dimensionality of the context (key/value) vectors.
            If None, the module operates in self-attention mode using query_dim. Default: None
        n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8
        head_dim (int, optional): The dimension of each attention head. Default: 64
        dropout (float, optional): Dropout probability applied to the output. Default: 0.0
        operations: Module namespace providing Linear/RMSNorm layer classes
            (ComfyUI operations pattern; e.g. behaves like ``torch.nn``).
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: Optional[int] = None,
        n_heads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        logging.debug(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{n_heads} heads with a dimension of {head_dim}."
        )
        self.is_selfattn = context_dim is None  # self attention

        # In self-attention mode, keys/values are projected from the query input.
        context_dim = query_dim if context_dim is None else context_dim
        inner_dim = head_dim * n_heads

        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim

        # Bias-free projections; Q and K get per-head RMS normalization, V does not.
        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.v_norm = nn.Identity()

        self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
        # Dropout module is skipped entirely for (near-)zero probabilities.
        self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()

        self.attn_op = torch_attention_op

        self._query_dim = query_dim
        self._context_dim = context_dim
        self._inner_dim = inner_dim

    def compute_qkv(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project inputs to per-head Q/K/V of shape (B, ..., H, D), RMS-normalize
        Q and K, and apply rotary embeddings (self-attention only)."""
        q = self.q_proj(x)
        context = x if context is None else context
        k = self.k_proj(context)
        v = self.v_proj(context)
        # Split the flat (H*D) projection into separate head/dim axes.
        q, k, v = map(
            lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim),
            (q, k, v),
        )

        def apply_norm_and_rotary_pos_emb(
            q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor]
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # v_norm is Identity; q_norm/k_norm are per-head RMSNorm.
            q = self.q_norm(q)
            k = self.k_norm(k)
            v = self.v_norm(v)
            if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
                q = apply_rotary_pos_emb(q, rope_emb)
                k = apply_rotary_pos_emb(k, rope_emb)
            return q, k, v

        q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)

        return q, k, v

    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Run the attention op and project the merged heads back to query_dim."""
        result = self.attn_op(q, k, v)  # [B, S, H, D]
        return self.output_dropout(self.output_proj(result))

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        rope_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
            rope_emb (Optional[Tensor]): Rotary embedding; used only in self-attention mode.
        """
        q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
        return self.compute_attention(q, k, v)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class Timesteps(nn.Module):
    """Sinusoidal timestep embedding: maps (B, T) timesteps to (B, T, num_channels).

    The first half of the channels are cosines and the second half sines, with
    frequencies log-spaced from 1 down toward 1/10000.
    """

    def __init__(self, num_channels: int):
        super().__init__()
        # Total embedding width; cos and sin each use num_channels // 2.
        self.num_channels = num_channels

    def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor:
        assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}"
        flat_steps = timesteps_B_T.flatten().float()
        half_dim = self.num_channels // 2
        # Log-spaced frequencies: exp(-ln(10000) * i / half_dim), i = 0..half_dim-1.
        freq_exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=flat_steps.device)
        freq_exponent = freq_exponent / (half_dim - 0.0)
        freqs = torch.exp(freq_exponent)
        angles = flat_steps[:, None].float() * freqs[None, :]

        # Cosine features first, then sine features.
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

        return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1])
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class TimestepEmbedding(nn.Module):
    """MLP head over a timestep embedding, optionally producing AdaLN-LoRA terms.

    With ``use_adaln_lora`` the second linear emits 3*out_features channels for
    downstream AdaLN-LoRA modulation and the raw input is returned as the
    embedding; otherwise the MLP output itself is the embedding.
    """

    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
        super().__init__()
        logging.debug(
            f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
        )
        self.in_dim = in_features
        self.out_dim = out_features
        # Bias only in the non-LoRA configuration (backward compatibility).
        self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        second_out = 3 * out_features if use_adaln_lora else out_features
        self.linear_2 = operations.Linear(out_features, second_out, bias=False, device=device, dtype=dtype)

    def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        mlp_out = self.linear_2(self.activation(self.linear_1(sample)))

        if self.use_adaln_lora:
            # LoRA path: the MLP output feeds AdaLN-LoRA modulation, and the
            # raw input passes through as the embedding.
            return sample, mlp_out
        return mlp_out, None
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class PatchEmbed(nn.Module):
    """
    PatchEmbed embeds patches from an input tensor by folding each non-overlapping
    spatio-temporal patch into the channel axis and applying a single bias-free
    linear projection. This module can process inputs with temporal (video) and
    spatial (image) dimensions, making it suitable for video and image processing
    tasks. Each patch is embedded into a vector of size `out_channels`.

    Parameters:
        - spatial_patch_size (int): The size of each spatial patch.
        - temporal_patch_size (int): The size of each temporal patch.
        - in_channels (int): Number of input channels. Default: 3.
        - out_channels (int): The dimension of the embedding vector for each patch. Default: 768.
        - operations: Module namespace providing the Linear layer class (ComfyUI ops pattern).
    """

    def __init__(
        self,
        spatial_patch_size: int,
        temporal_patch_size: int,
        in_channels: int = 3,
        out_channels: int = 768,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        # Fold each (temporal x spatial x spatial) patch into the channel axis,
        # then project the flattened patch to the embedding width (no bias).
        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            operations.Linear(
                in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype
            ),
        )
        # Flattened per-patch input dimension (before projection).
        self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the PatchEmbed module.

        Parameters:
            - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where
                B is the batch size,
                C is the number of channels,
                T is the temporal dimension,
                H is the height, and
                W is the width of the input.

        Returns:
            - torch.Tensor: The embedded patches as a tensor with shape
              (B, T/temporal_patch_size, H/spatial_patch_size, W/spatial_patch_size, out_channels).
        """
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        # Extents must tile exactly into whole patches.
        assert (
            H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}"
        assert T % self.temporal_patch_size == 0
        x = self.proj(x)
        return x
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class FinalLayer(nn.Module):
    """
    The final layer of video DiT: AdaLN-modulated LayerNorm followed by a
    bias-free linear projecting features to per-patch output channels.
    """

    def __init__(
        self,
        hidden_size: int,
        spatial_patch_size: int,
        temporal_patch_size: int,
        out_channels: int,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        # elementwise_affine=False: the scale/shift come from AdaLN, not LayerNorm.
        self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = operations.Linear(
            hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
        )
        self.hidden_size = hidden_size
        # AdaLN here produces two chunks: shift and scale.
        self.n_adaln_chunks = 2
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        if use_adaln_lora:
            # Low-rank AdaLN: bottleneck through adaln_lora_dim.
            self.adaln_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype),
            )
        else:
            self.adaln_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype)
            )

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
    ):
        modulation_B_T_2D = self.adaln_modulation(emb_B_T_D)
        if self.use_adaln_lora:
            assert adaln_lora_B_T_3D is not None
            # Add the per-token LoRA correction (first 2*hidden_size channels).
            modulation_B_T_2D = modulation_B_T_2D + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
        shift_B_T_D, scale_B_T_D = modulation_B_T_2D.chunk(2, dim=-1)

        # Broadcast the (B, T, D) modulation over the spatial axes: (B, T, 1, 1, D).
        shift_B_T_1_1_D = shift_B_T_D.unsqueeze(2).unsqueeze(2)
        scale_B_T_1_1_D = scale_B_T_D.unsqueeze(2).unsqueeze(2)

        # AdaLN: normalize, then apply the conditioned scale and shift.
        x_B_T_H_W_D = self.layer_norm(x_B_T_H_W_D) * (1 + scale_B_T_1_1_D) + shift_B_T_1_1_D
        return self.linear(x_B_T_H_W_D)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class Block(nn.Module):
    """
    A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation.
    Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation.

    Parameters:
        x_dim (int): Dimension of input features
        context_dim (int): Dimension of context features for cross-attention
        num_heads (int): Number of attention heads
        mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0
        use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False
        adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256

    The block applies the following sequence:
    1. Self-attention with AdaLN modulation
    2. Cross-attention with AdaLN modulation
    3. MLP with AdaLN modulation

    Each component uses skip connections and layer normalization.
    """

    def __init__(
        self,
        x_dim: int,
        context_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.x_dim = x_dim
        # LayerNorms are affine-free (elementwise_affine=False): the scale/shift
        # come from the AdaLN modulation branches computed in forward().
        self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations)

        self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.cross_attn = Attention(
            x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations
        )

        self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
        self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations)

        self.use_adaln_lora = use_adaln_lora
        if self.use_adaln_lora:
            # Low-rank AdaLN heads (x_dim -> adaln_lora_dim -> 3*x_dim); the
            # per-timestep LoRA term adaln_lora_B_T_3D is added in forward().
            self.adaln_modulation_self_attn = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
            self.adaln_modulation_cross_attn = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
            self.adaln_modulation_mlp = nn.Sequential(
                nn.SiLU(),
                operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype),
                operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype),
            )
        else:
            # Full-rank AdaLN heads producing (shift, scale, gate) per component.
            self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
            self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))
            self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype))

    def forward(
        self,
        x_B_T_H_W_D: torch.Tensor,
        emb_B_T_D: torch.Tensor,
        crossattn_emb: torch.Tensor,
        rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
        adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
        extra_per_block_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attn -> cross-attn -> MLP, each AdaLN-modulated and gated.

        Args:
            x_B_T_H_W_D: video tokens, shape (B, T, H, W, D).
            emb_B_T_D: per-timestep conditioning embedding, shape (B, T, D).
            crossattn_emb: cross-attention context, shape (B, N, context_dim).
            rope_emb_L_1_1_D: optional RoPE embedding passed to both attentions.
            adaln_lora_B_T_3D: optional LoRA residual added to the AdaLN output
                (required when use_adaln_lora=True; the code assumes it is set).
            extra_per_block_pos_emb: optional absolute pos-emb added to x first.
        """
        if extra_per_block_pos_emb is not None:
            x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

        # Each component gets its own (shift, scale, gate) triple from emb_B_T_D.
        if self.use_adaln_lora:
            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
                self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
                self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
                self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
            ).chunk(3, dim=-1)
        else:
            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
                emb_B_T_D
            ).chunk(3, dim=-1)
            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
                emb_B_T_D
            ).chunk(3, dim=-1)
            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)

        # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting
        shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
        scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d")
        gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d")

        shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d")
        scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d")
        gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d")

        shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d")
        scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d")
        gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d")

        B, T, H, W, D = x_B_T_H_W_D.shape

        # AdaLN: normalize then apply the learned affine modulation.
        def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D):
            return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D

        # --- self-attention (tokens flattened to one (t h w) sequence) ---
        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
            self.layer_norm_self_attn,
            scale_self_attn_B_T_1_1_D,
            shift_self_attn_B_T_1_1_D,
        )
        result_B_T_H_W_D = rearrange(
            self.self_attn(
                # normalized_x_B_T_HW_D,
                rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                None,
                rope_emb=rope_emb_L_1_1_D,
            ),
            "b (t h w) d -> b t h w d",
            t=T,
            h=H,
            w=W,
        )
        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D

        # --- cross-attention against crossattn_emb ---
        def _x_fn(
            _x_B_T_H_W_D: torch.Tensor,
            layer_norm_cross_attn: Callable,
            _scale_cross_attn_B_T_1_1_D: torch.Tensor,
            _shift_cross_attn_B_T_1_1_D: torch.Tensor,
        ) -> torch.Tensor:
            _normalized_x_B_T_H_W_D = _fn(
                _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
            )
            _result_B_T_H_W_D = rearrange(
                self.cross_attn(
                    rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
                    crossattn_emb,
                    rope_emb=rope_emb_L_1_1_D,
                ),
                "b (t h w) d -> b t h w d",
                t=T,
                h=H,
                w=W,
            )
            return _result_B_T_H_W_D

        result_B_T_H_W_D = _x_fn(
            x_B_T_H_W_D,
            self.layer_norm_cross_attn,
            scale_cross_attn_B_T_1_1_D,
            shift_cross_attn_B_T_1_1_D,
        )
        x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D

        # --- MLP ---
        normalized_x_B_T_H_W_D = _fn(
            x_B_T_H_W_D,
            self.layer_norm_mlp,
            scale_mlp_B_T_1_1_D,
            shift_mlp_B_T_1_1_D,
        )
        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
        return x_B_T_H_W_D
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class MiniTrainDIT(nn.Module):
    """
    A clean impl of DIT that can load and reproduce the training results of the original DIT model in~(cosmos 1)
    A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing.

    Args:
        max_img_h (int): Maximum height of the input images.
        max_img_w (int): Maximum width of the input images.
        max_frames (int): Maximum number of frames in the video sequence.
        in_channels (int): Number of input channels (e.g., RGB channels for color images).
        out_channels (int): Number of output channels.
        patch_spatial (tuple): Spatial resolution of patches for input processing.
        patch_temporal (int): Temporal resolution of patches for input processing.
        concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding.
        model_channels (int): Base number of channels used throughout the model.
        num_blocks (int): Number of transformer blocks.
        num_heads (int): Number of heads in the multi-head attention layers.
        mlp_ratio (float): Expansion ratio for MLP blocks.
        crossattn_emb_channels (int): Number of embedding channels for cross-attention.
        pos_emb_cls (str): Type of positional embeddings.
        pos_emb_learnable (bool): Whether positional embeddings are learnable.
        pos_emb_interpolation (str): Method for interpolating positional embeddings.
        min_fps (int): Minimum frames per second.
        max_fps (int): Maximum frames per second.
        use_adaln_lora (bool): Whether to use AdaLN-LoRA.
        adaln_lora_dim (int): Dimension for AdaLN-LoRA.
        rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE.
        rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE.
        rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE.
        extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings.
        extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings.
        extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings.
        extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings.
    """

    def __init__(
        self,
        max_img_h: int,
        max_img_w: int,
        max_frames: int,
        in_channels: int,
        out_channels: int,
        patch_spatial: int, # tuple,
        patch_temporal: int,
        concat_padding_mask: bool = True,
        # attention settings
        model_channels: int = 768,
        num_blocks: int = 10,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        # cross attention settings
        crossattn_emb_channels: int = 1024,
        # positional embedding settings
        pos_emb_cls: str = "sincos",
        pos_emb_learnable: bool = False,
        pos_emb_interpolation: str = "crop",
        min_fps: int = 1,
        max_fps: int = 30,
        use_adaln_lora: bool = False,
        adaln_lora_dim: int = 256,
        rope_h_extrapolation_ratio: float = 1.0,
        rope_w_extrapolation_ratio: float = 1.0,
        rope_t_extrapolation_ratio: float = 1.0,
        extra_per_block_abs_pos_emb: bool = False,
        extra_h_extrapolation_ratio: float = 1.0,
        extra_w_extrapolation_ratio: float = 1.0,
        extra_t_extrapolation_ratio: float = 1.0,
        rope_enable_fps_modulation: bool = True,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        self.max_img_h = max_img_h
        self.max_img_w = max_img_w
        self.max_frames = max_frames
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.model_channels = model_channels
        self.concat_padding_mask = concat_padding_mask
        # positional embedding settings
        self.pos_emb_cls = pos_emb_cls
        self.pos_emb_learnable = pos_emb_learnable
        self.pos_emb_interpolation = pos_emb_interpolation
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio
        self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio
        self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio
        self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb
        self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio
        self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio
        self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio
        self.rope_enable_fps_modulation = rope_enable_fps_modulation

        self.build_pos_embed(device=device, dtype=dtype)
        self.use_adaln_lora = use_adaln_lora
        self.adaln_lora_dim = adaln_lora_dim
        # t_embedder[0] maps raw timesteps to sinusoidal features;
        # t_embedder[1] projects them (and optionally produces the AdaLN-LoRA term).
        self.t_embedder = nn.Sequential(
            Timesteps(model_channels),
            TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,),
        )

        # One extra input channel carries the padding mask when enabled.
        in_channels = in_channels + 1 if concat_padding_mask else in_channels
        self.x_embedder = PatchEmbed(
            spatial_patch_size=patch_spatial,
            temporal_patch_size=patch_temporal,
            in_channels=in_channels,
            out_channels=model_channels,
            device=device, dtype=dtype, operations=operations,
        )

        self.blocks = nn.ModuleList(
            [
                Block(
                    x_dim=model_channels,
                    context_dim=crossattn_emb_channels,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    use_adaln_lora=use_adaln_lora,
                    adaln_lora_dim=adaln_lora_dim,
                    device=device, dtype=dtype, operations=operations,
                )
                for _ in range(num_blocks)
            ]
        )

        self.final_layer = FinalLayer(
            hidden_size=self.model_channels,
            spatial_patch_size=self.patch_spatial,
            temporal_patch_size=self.patch_temporal,
            out_channels=self.out_channels,
            use_adaln_lora=self.use_adaln_lora,
            adaln_lora_dim=self.adaln_lora_dim,
            device=device, dtype=dtype, operations=operations,
        )

        self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype)

    def build_pos_embed(self, device=None, dtype=None) -> None:
        """Construct self.pos_embedder (and self.extra_pos_embedder when enabled).

        Only the "rope3d" positional-embedding class is supported here.
        """
        if self.pos_emb_cls == "rope3d":
            cls_type = VideoRopePosition3DEmb
        else:
            raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")

        logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
        kwargs = dict(
            model_channels=self.model_channels,
            len_h=self.max_img_h // self.patch_spatial,
            len_w=self.max_img_w // self.patch_spatial,
            len_t=self.max_frames // self.patch_temporal,
            max_fps=self.max_fps,
            min_fps=self.min_fps,
            is_learnable=self.pos_emb_learnable,
            interpolation=self.pos_emb_interpolation,
            head_dim=self.model_channels // self.num_heads,
            h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
            w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
            t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
            enable_fps_modulation=self.rope_enable_fps_modulation,
            device=device,
        )
        self.pos_embedder = cls_type(
            **kwargs,  # type: ignore
        )

        if self.extra_per_block_abs_pos_emb:
            # Reuse the same kwargs but with the "extra" extrapolation ratios.
            kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio
            kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
            kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
            kwargs["device"] = device
            kwargs["dtype"] = dtype
            self.extra_pos_embedder = LearnablePosEmbAxis(
                **kwargs,  # type: ignore
            )

    def prepare_embedded_sequence(
        self,
        x_B_C_T_H_W: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.

        Args:
            x_B_C_T_H_W (torch.Tensor): video
            fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
                If None, a default value (`self.base_fps`) will be used.
            padding_mask (Optional[torch.Tensor]): current it is not used

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - A tensor of shape (B, T, H, W, D) with the embedded sequence.
                - An optional positional embedding tensor, returned only if the positional embedding class
                (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.

        Notes:
            - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
            - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
            - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
                the `self.pos_embedder` with the shape [T, H, W].
            - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the
                `self.pos_embedder` with the fps tensor.
            - Otherwise, the positional embeddings are generated without considering fps.
        """
        if self.concat_padding_mask:
            # Missing mask -> all-zeros channel (nothing is padded); otherwise
            # resize the supplied mask to the input's spatial size.
            if padding_mask is None:
                padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device)
            else:
                padding_mask = transforms.functional.resize(
                    padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
                )
            # Broadcast the (B, 1, H, W) mask over time and append as a channel.
            x_B_C_T_H_W = torch.cat(
                [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
            )
        x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)

        if self.extra_per_block_abs_pos_emb:
            extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
        else:
            extra_pos_emb = None

        if "rope" in self.pos_emb_cls.lower():
            # RoPE embeddings are applied inside attention, so they are
            # returned separately instead of being added to the tokens here.
            return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb
        x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device)  # [B, T, H, W, D]

        return x_B_T_H_W_D, None, extra_pos_emb

    def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor:
        """Fold patch tokens back into pixel space: (B, T, H, W, p1*p2*t*C) -> (B, C, T*t, H*p1, W*p2)."""
        x_B_C_Tt_Hp_Wp = rearrange(
            x_B_T_H_W_M,
            "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)",
            p1=self.patch_spatial,
            p2=self.patch_spatial,
            t=self.patch_temporal,
        )
        return x_B_C_Tt_Hp_Wp

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        x_B_C_T_H_W = x
        timesteps_B_T = timesteps
        crossattn_emb = context
        """
        Args:
            x: (B, C, T, H, W) tensor of spatial-temp inputs
            timesteps: (B, ) tensor of timesteps
            crossattn_emb: (B, N, D) tensor of cross-attention embeddings
        """
        x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence(
            x_B_C_T_H_W,
            fps=fps,
            padding_mask=padding_mask,
        )

        # Per-sample timesteps (B,) are promoted to per-frame (B, 1).
        if timesteps_B_T.ndim == 1:
            timesteps_B_T = timesteps_B_T.unsqueeze(1)
        t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype))
        t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D)

        # for logging purpose
        affline_scale_log_info = {}
        affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach()
        self.affline_scale_log_info = affline_scale_log_info
        self.affline_emb = t_embedding_B_T_D
        self.crossattn_emb = crossattn_emb

        if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
            assert (
                x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
            ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}"

        block_kwargs = {
            "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
            "adaln_lora_B_T_3D": adaln_lora_B_T_3D,
            "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
        }
        for block in self.blocks:
            x_B_T_H_W_D = block(
                x_B_T_H_W_D,
                t_embedding_B_T_D,
                crossattn_emb,
                **block_kwargs,
            )

        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
        x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
        return x_B_C_Tt_Hp_Wp
|
ComfyUI/comfy/ldm/cosmos/vae.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
from enum import Enum
|
| 21 |
+
import math
|
| 22 |
+
|
| 23 |
+
from .cosmos_tokenizer.layers3d import (
|
| 24 |
+
EncoderFactorized,
|
| 25 |
+
DecoderFactorized,
|
| 26 |
+
CausalConv3d,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class IdentityDistribution(torch.nn.Module):
    """No-op distribution head: passes latents through unchanged (AE formulation)."""

    def forward(self, parameters):
        # Zero mean/logvar placeholders keep the return shape aligned with
        # GaussianDistribution's (sample, (mean, logvar)) contract.
        placeholder = (torch.tensor([0.0]), torch.tensor([0.0]))
        return parameters, placeholder
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class GaussianDistribution(torch.nn.Module):
    """Diagonal Gaussian head: splits channels into (mean, logvar) and samples.

    logvar is clamped to [min_logvar, max_logvar] before sampling to keep the
    standard deviation numerically sane.
    """

    def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
        super().__init__()
        self.min_logvar = min_logvar
        self.max_logvar = max_logvar

    def sample(self, mean, logvar):
        """Reparameterized draw: mean + sigma * eps with eps ~ N(0, I)."""
        sigma = torch.exp(0.5 * logvar)
        return sigma * torch.randn_like(mean) + mean

    def forward(self, parameters):
        """Split `parameters` along dim 1 into mean/logvar and return a sample.

        Returns:
            (sample, (mean, clamped_logvar))
        """
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        logvar = logvar.clamp(self.min_logvar, self.max_logvar)
        return self.sample(mean, logvar), (mean, logvar)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ContinuousFormulation(Enum):
    # Maps a formulation name to the distribution module applied after the
    # encoder: VAE samples from a diagonal Gaussian, AE passes latents through.
    VAE = GaussianDistribution
    AE = IdentityDistribution
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class CausalContinuousVideoTokenizer(nn.Module):
    """Causal continuous video tokenizer (AE formulation) for 3D video data.

    Wraps a factorized encoder/decoder pair with kernel-1 causal convs mapping
    to/from the latent space, and normalizes latents using stored per-channel,
    per-temporal-chunk mean/std statistics scaled by sigma_data.
    """

    def __init__(
        self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
    ) -> None:
        super().__init__()
        self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
        self.latent_channels = latent_channels
        # Target latent std used by encode()/decode() (EDM-style data sigma).
        self.sigma_data = 0.5

        # encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
        self.encoder = EncoderFactorized(
            z_channels=z_factor * z_channels, **kwargs
        )
        # NOTE(review): the decoder is built with a reduced channels_mult when
        # temporal compression is 4 — presumably to match the released
        # checkpoint layout; confirm against the upstream config.
        if kwargs.get("temporal_compression", 4) == 4:
            kwargs["channels_mult"] = [2, 4]
        # decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
        self.decoder = DecoderFactorized(
            z_channels=z_channels, **kwargs
        )

        # 1x1x1 causal convs projecting encoder output to latent parameters
        # and latents back to decoder input.
        self.quant_conv = CausalConv3d(
            z_factor * z_channels,
            z_factor * latent_channels,
            kernel_size=1,
            padding=0,
        )
        self.post_quant_conv = CausalConv3d(
            latent_channels, z_channels, kernel_size=1, padding=0
        )

        # formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
        self.distribution = IdentityDistribution()  # ContinuousFormulation[formulation_name].value()

        num_parameters = sum(param.numel() for param in self.parameters())
        logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
        logging.debug(
            f"z_channels={z_channels}, latent_channels={self.latent_channels}."
        )

        # Latent statistics are stored flat as (latent_channels * chunk) values,
        # i.e. one mean/std per channel per frame within a 16-frame chunk.
        latent_temporal_chunk = 16
        self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))
        self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32))


    def encode(self, x):
        """Encode video x to normalized latents: ((z - mean) / std) * sigma_data."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        z, posteriors = self.distribution(moments)
        latent_ch = z.shape[1]
        latent_t = z.shape[2]
        in_dtype = z.dtype
        # Reshape the flat stats to (C, chunk), tile along time to cover
        # latent_t frames, then crop and broadcast to (1, C, latent_t, 1, 1).
        mean = self.latent_mean.view(latent_ch, -1)
        std = self.latent_std.view(latent_ch, -1)

        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
        return ((z - mean) / std) * self.sigma_data

    def decode(self, z):
        """Invert encode's normalization, then decode latents back to video."""
        in_dtype = z.dtype
        latent_ch = z.shape[1]
        latent_t = z.shape[2]
        # Same tiling of the stored statistics as in encode().
        mean = self.latent_mean.view(latent_ch, -1)
        std = self.latent_std.view(latent_ch, -1)

        mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
        std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)

        z = z / self.sigma_data
        z = z * std + mean
        z = self.post_quant_conv(z)
        return self.decoder(z)
|
| 131 |
+
|
ComfyUI/comfy/ldm/flux/controlnet.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
| 2 |
+
#modified to support different types of flux controlnets
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import math
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
|
| 9 |
+
from .layers import (timestep_embedding)
|
| 10 |
+
|
| 11 |
+
from .model import Flux
|
| 12 |
+
import comfy.ldm.common_dit
|
| 13 |
+
|
| 14 |
+
class MistolineCondDownsamplBlock(nn.Module):
    """Downsamples a 3-channel condition image 8x spatially into 16 channels.

    A stem conv lifts the input to 16 channels, followed by SiLU-separated
    convolutions, three of which use stride 2 (8x total spatial reduction).
    """

    def __init__(self, dtype=None, device=None, operations=None):
        super().__init__()

        def conv16(kernel, stride=1):
            # 16->16 conv; kernel-3 convs keep spatial size via padding=1.
            return operations.Conv2d(
                16, 16, kernel, padding=1 if kernel == 3 else 0, stride=stride, dtype=dtype, device=device
            )

        layers = [operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device)]
        for kernel, stride in ((1, 1), (3, 1), (3, 2), (3, 1), (3, 2), (3, 1), (3, 2), (1, 1), (3, 1)):
            layers.append(nn.SiLU())
            layers.append(conv16(kernel, stride))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 8x-downsampled 16-channel feature map."""
        return self.encoder(x)
|
| 41 |
+
|
| 42 |
+
class MistolineControlnetBlock(nn.Module):
    """Per-block output projection used by the Mistoline ControlNet variant.

    A single hidden_size -> hidden_size Linear followed by SiLU.
    """

    def __init__(self, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
        self.act = nn.SiLU()

    def forward(self, x):
        """Project x and apply the SiLU nonlinearity."""
        projected = self.linear(x)
        return self.act(projected)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ControlNetFlux(Flux):
    """Flux-based ControlNet supporting several controlnet flavors.

    Reuses the Flux backbone (without its final layer) and taps the output
    of every double and single block through small per-block projections
    whose outputs are returned as residuals for the main model.

    Supported variants (selected by constructor flags):
      * latent_input: hint is already a latent, only padded/patchified.
      * mistoline: hint is an image encoded by MistolineCondDownsamplBlock,
        and the per-block taps are Linear+SiLU instead of plain Linear.
      * union mode: controlnet_mode_embedder prepends learned mode tokens
        to the text stream (ControlNet-Union style).
    """

    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        # final_layer=False: the controlnet never decodes back to pixels.
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        # Block counts of the full-size main Flux model; controlnet outputs
        # are tiled up to these lengths in forward_orig.
        self.main_model_double = 19
        self.main_model_single = 38

        self.mistoline = mistoline
        # add ControlNet blocks (one tap per backbone block)
        if self.mistoline:
            control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
        else:
            control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)

        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(self.params.depth):
            self.controlnet_blocks.append(control_block())

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(self.params.depth_single_blocks):
            self.controlnet_single_blocks.append(control_block())

        self.num_union_modes = num_union_modes
        self.controlnet_mode_embedder = None
        if self.num_union_modes > 0:
            self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)

        self.gradient_checkpointing = False
        self.latent_input = latent_input
        if control_latent_channels is None:
            control_latent_channels = self.in_channels
        else:
            control_latent_channels *= 2 * 2 #patch size

        self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        if not self.latent_input:
            if self.mistoline:
                self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
            else:
                # standard hint encoder: 3 -> 16 channels, 8x downsample
                self.input_hint_block = nn.Sequential(
                    operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
                    nn.SiLU(),
                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
                )

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control_type: Tensor = None,
    ) -> Tensor:
        """Run the controlnet backbone on pre-patchified sequences.

        Returns a dict with "input" (residuals for the main model's double
        blocks) and, when single blocks exist, "output" (residuals for the
        main model's single blocks).
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        if y is None:
            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
        else:
            y = y[:, :self.params.vec_in_dim]

        # running on sequences img
        img = self.img_in(img)

        # inject the control hint additively into the image stream
        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # union mode: prepend mode embedding tokens to the text stream and
        # duplicate the leading txt id so positions stay aligned
        if self.controlnet_mode_embedder is not None and len(control_type) > 0:
            control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
            txt = torch.cat([control_cond, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        controlnet_double = ()

        for i in range(len(self.double_blocks)):
            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
            controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)

        img = torch.cat((txt, img), 1)

        controlnet_single = ()

        for i in range(len(self.single_blocks)):
            img = self.single_blocks[i](img, vec=vec, pe=pe)
            # single blocks operate on [txt, img]; tap only the img tokens
            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)

        # tile the taps so there is one residual per main-model block.
        # NOTE: this local `repeat` shadows the einops `repeat` import
        # inside this method.
        repeat = math.ceil(self.main_model_double / len(controlnet_double))
        if self.latent_input:
            # latent-input controlnets repeat each tap consecutively
            out_input = ()
            for x in controlnet_double:
                out_input += (x,) * repeat
        else:
            # others cycle through the whole tap sequence
            out_input = (controlnet_double * repeat)

        out = {"input": out_input[:self.main_model_double]}
        if len(controlnet_single) > 0:
            repeat = math.ceil(self.main_model_single / len(controlnet_single))
            out_output = ()
            if self.latent_input:
                for x in controlnet_single:
                    out_output += (x,) * repeat
            else:
                out_output = (controlnet_single * repeat)
            out["output"] = out_output[:self.main_model_single]
        return out

    def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
        """Prepare hint and latent images (encode, patchify, build position
        ids) and delegate to forward_orig."""
        patch_size = 2
        if self.latent_input:
            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
        elif self.mistoline:
            # hint images come in [0, 1]; rescale to [-1, 1] before encoding
            hint = hint * 2.0 - 1.0
            hint = self.input_cond_block(hint)
        else:
            hint = hint * 2.0 - 1.0
            hint = self.input_hint_block(hint)

        hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        # rounded-up patch grid dimensions (matches the padding above)
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        # img_ids[..., 1] = row index, img_ids[..., 2] = column index
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
ComfyUI/comfy/ldm/flux/layers.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
|
| 7 |
+
from .math import attention, rope
|
| 8 |
+
import comfy.ops
|
| 9 |
+
import comfy.ldm.common_dit
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class EmbedND(nn.Module):
    """Multi-axis rotary position embedder.

    Each axis of the id tensor's last dimension gets its own rotary
    frequency table (axes_dim[i] channels, shared theta); the per-axis
    tables are concatenated along the frequency dimension.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Return the rotary embedding for ids with a head axis inserted."""
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: scale applied to t before embedding.
    :return: an (N, dim) Tensor of positional embeddings.
    """
    scaled = time_factor * t
    half = dim // 2
    exponent = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
    freqs = torch.exp(-math.log(max_period) * exponent)

    phases = scaled[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
    if dim % 2:
        # odd output dims get a zero pad column
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        emb = emb.to(t)
    return emb
|
| 49 |
+
|
| 50 |
+
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) for embedding conditioning
    vectors into the transformer's hidden dimension."""

    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.silu = nn.SiLU()
        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor) -> Tensor:
        """Project x to hidden_dim, apply SiLU, project again."""
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class RMSNorm(torch.nn.Module):
    # RMS normalization with a learned per-channel scale. The actual math
    # is delegated to comfy.ldm.common_dit.rms_norm with eps=1e-6.
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        # scale is deliberately uninitialized (torch.empty); values are
        # expected to be loaded from a checkpoint before use.
        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

    def forward(self, x: Tensor):
        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class QKNorm(torch.nn.Module):
    """Applies independent RMSNorm to attention queries and keys."""

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
        """Normalize q and k and cast both to v's dtype; v is untouched."""
        normed_q = self.query_norm(q).to(v)
        normed_k = self.key_norm(k).to(v)
        return normed_q, normed_k
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class SelfAttention(nn.Module):
    # Parameter container for a Flux attention layer: fused qkv projection,
    # per-head q/k RMS normalization, and the output projection. Note there
    # is no forward(); the attention itself is computed by the caller
    # (DoubleStreamBlock), which uses these submodules directly.
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@dataclass
class ModulationOut:
    """One adaLN modulation triple: x -> x * (1 + scale) + shift, with gate
    weighting the residual added back to the stream."""
    shift: Tensor
    scale: Tensor
    gate: Tensor
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class Modulation(nn.Module):
    """Computes adaLN modulation parameters from a conditioning vector.

    Produces one (shift, scale, gate) triple, or two when double=True
    (DoubleStreamBlock uses separate triples for attention and MLP).
    """

    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

    def forward(self, vec: Tensor) -> tuple:
        """Return (ModulationOut, ModulationOut | None) derived from vec."""
        if vec.ndim == 2:
            # insert a token axis so modulation broadcasts over the sequence
            vec = vec[:, None, :]
        chunks = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
        first = ModulationOut(*chunks[:3])
        second = ModulationOut(*chunks[3:]) if self.is_double else None
        return first, second
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
    """Apply multiplicative (and optionally additive) modulation to tensor.

    When modulation_dims is None the whole tensor is modulated and a NEW
    tensor is returned. Otherwise each (start, end, mod_index) entry
    modulates tensor[:, start:end] IN PLACE using slice mod_index of the
    modulation tensors, and the (mutated) input tensor is returned.
    """
    if modulation_dims is None:
        if m_add is None:
            return tensor * m_mult
        # fused tensor * m_mult + m_add
        return torch.addcmul(m_add, tensor, m_mult)

    for d in modulation_dims:
        start, end, mod_idx = d[0], d[1], d[2]
        tensor[:, start:end] *= m_mult[:, mod_idx]
        if m_add is not None:
            tensor[:, start:end] += m_add[:, mod_idx]
    return tensor
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class DoubleStreamBlock(nn.Module):
    # MMDiT-style block: image and text token streams are processed with
    # separate weights but share one joint attention over the concatenated
    # sequence. Per-stream shift/scale/gate modulation is derived from the
    # conditioning vector `vec`.
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        # when True, the joint sequence is ordered [img, txt] rather than
        # the default [txt, img]
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
        # two modulation triples per stream: one for attention, one for MLP
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention: norm, modulate, project to per-head
        # q/k/v of shape (3, B, heads, seq, head_dim), then q/k-normalize
        img_modulated = self.img_norm1(img)
        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention (same pipeline, separate weights)
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        if self.flipped_img_txt:
            # run joint attention over [img, txt] and split back
            attn = attention(torch.cat((img_q, txt_q), dim=2),
                             torch.cat((img_k, txt_k), dim=2),
                             torch.cat((img_v, txt_v), dim=2),
                             pe=pe, mask=attn_mask)

            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            # run joint attention over [txt, img] and split back
            attn = attention(torch.cat((txt_q, img_q), dim=2),
                             torch.cat((txt_k, img_k), dim=2),
                             torch.cat((txt_v, img_v), dim=2),
                             pe=pe, mask=attn_mask)

            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img blocks: gated attention and MLP residuals
        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

        # calculate the txt blocks (in-place additions on the txt stream)
        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

        if txt.dtype == torch.float16:
            # clamp fp16 overflow/NaN to the finite half-precision range
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    Attention and MLP branches share fused input (linear1) and output
    (linear2) projections and operate on the single concatenated
    [txt, img] sequence.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # NOTE: self.scale is computed but not used in forward(); attention
        # scaling happens inside the attention() helper.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in (fused into one projection)
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out (fused into one projection)
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
        # single modulation triple (double=False -> second element is None)
        mod, _ = self.modulation(vec)
        # one fused projection produces both the qkv tensor and the MLP input
        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # gated residual, added in place to the input stream
        x += apply_mod(output, mod.gate, None, modulation_dims)
        if x.dtype == torch.float16:
            # clamp fp16 overflow/NaN to the finite half-precision range
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class LastLayer(nn.Module):
    """Final adaLN-modulated projection from hidden tokens to patch pixels."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
        """Normalize x, modulate with vec-derived shift/scale, then project."""
        if vec.ndim == 2:
            # add a token axis so the modulation broadcasts over the sequence
            vec = vec[:, None, :]

        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
        modulated = apply_mod(self.norm_final(x), scale + 1, shift, modulation_dims)
        return self.linear(modulated)
|
ComfyUI/comfy/ldm/flux/math.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
|
| 5 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 6 |
+
import comfy.model_management
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    # Scaled dot-product attention with optional rotary position embedding.
    # heads is read from q.shape[1], so q/k/v are laid out as
    # (batch, heads, seq, head_dim); pe is the rotation table produced by
    # rope()/EmbedND, or None to skip RoPE entirely.
    q_shape = q.shape
    k_shape = k.shape

    if pe is not None:
        # view the head_dim as (real, imag) pairs and apply the 2x2
        # rotations stored in pe (same math as apply_rope below)
        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
    return x
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    # Build rotary-embedding rotation matrices for positions `pos`.
    # Returns float32 [[cos, -sin], [sin, cos]] matrices, one per
    # position/frequency pair (last two dims are the 2x2 rotation).
    assert dim % 2 == 0
    # MPS/XPU/DirectML lack float64 support, so compute the frequency
    # table on CPU there and move the result back at the end.
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
        device = torch.device("cpu")
    else:
        device = pos.device

    # frequencies omega_d = theta^(-2d/dim), computed in float64 for accuracy
    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    """Rotate query and key tensors by a precomputed RoPE table.

    The last dimension of xq/xk is interpreted as (real, imag) pairs and
    freqs_cis holds 2x2 rotation matrices as produced by rope(). Returns
    the rotated (xq, xk), cast back to their input dtypes.
    """
    def _rotate(x: Tensor) -> Tensor:
        pairs = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
|
| 45 |
+
|
ComfyUI/comfy/ldm/flux/model.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Original code can be found on: https://github.com/black-forest-labs/flux
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
import comfy.ldm.common_dit
|
| 9 |
+
|
| 10 |
+
from .layers import (
|
| 11 |
+
DoubleStreamBlock,
|
| 12 |
+
EmbedND,
|
| 13 |
+
LastLayer,
|
| 14 |
+
MLPEmbedder,
|
| 15 |
+
SingleStreamBlock,
|
| 16 |
+
timestep_embedding,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
@dataclass
class FluxParams:
    """Hyperparameters for the Flux transformer."""
    in_channels: int            # latent channels of the input (pre-patchify)
    out_channels: int           # latent channels of the output
    vec_in_dim: int             # dim of the pooled conditioning vector y
    context_in_dim: int         # dim of the incoming text context embeddings
    hidden_size: int            # transformer width; divisible by num_heads
    mlp_ratio: float            # MLP hidden dim = hidden_size * mlp_ratio
    num_heads: int              # number of attention heads
    depth: int                  # number of DoubleStreamBlocks
    depth_single_blocks: int    # number of SingleStreamBlocks
    axes_dim: list              # per-axis rotary dims; sums to hidden_size // num_heads
    theta: int                  # rotary embedding base frequency
    patch_size: int             # spatial patch size used when patchifying latents
    qkv_bias: bool              # whether qkv projections carry a bias
    guidance_embed: bool        # whether a guidance embedding MLP is instantiated
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Flux(nn.Module):
|
| 38 |
+
"""
|
| 39 |
+
Transformer model for flow matching on sequences.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        """Build the Flux backbone from FluxParams given as **kwargs.

        final_layer=False is used by subclasses (e.g. ControlNetFlux) that
        never decode back to patch pixels. `image_model` is accepted for
        config compatibility but unused here.
        """
        super().__init__()
        self.dtype = dtype
        params = FluxParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        # channel counts after folding the patch into the channel dim
        self.in_channels = params.in_channels * params.patch_size * params.patch_size
        self.out_channels = params.out_channels * params.patch_size * params.patch_size
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
        # guidance embedder only exists for guidance-distilled checkpoints
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
            # patch_size=1 here: tokens were patchified before img_in, so
            # the final projection maps back to per-token pixel values
            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
| 90 |
+
|
| 91 |
+
    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control = None,
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:
        """Core Flux transformer pass over pre-patchified token sequences.

        img/txt are (batch, tokens, features); img_ids/txt_ids carry the 3-axis
        positional ids fed to the rotary embedder. Returns the image token
        sequence after the final projection layer.
        """
        # Missing pooled conditioning vector: substitute zeros of the expected width.
        if y is None:
            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)

        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        # Conditioning vector: timestep embedding, plus optional guidance
        # embedding and the pooled text vector.
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
        if self.params.guidance_embed:
            if guidance is not None:
                vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

        vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
        txt = self.txt_in(txt)

        # Rotary positional embedding over the concatenated txt+img id grid.
        if img_ids is not None:
            ids = torch.cat((txt_ids, img_ids), dim=1)
            pe = self.pe_embedder(ids)
        else:
            pe = None

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if ("double_block", i) in blocks_replace:
                # Patched path: an externally registered callable wraps the
                # block; block_wrap exposes the original computation to it.
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"],
                                                   txt=args["txt"],
                                                   vec=args["vec"],
                                                   pe=args["pe"],
                                                   attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("double_block", i)]({"img": img,
                                                           "txt": txt,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 attn_mask=attn_mask)

            if control is not None: # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

        # fp16 can overflow inside the blocks; clamp to the fp16 finite range.
        if img.dtype == torch.float16:
            img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)

        # Single-stream blocks run on the concatenated txt+img sequence.
        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"],
                                       vec=args["vec"],
                                       pe=args["pe"],
                                       attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("single_block", i)]({"img": img,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

            if control is not None: # Controlnet
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        # Only the image part of the joint sequence receives
                        # the controlnet residual.
                        img[:, txt.shape[1] :, ...] += add

        # Drop the text tokens; keep only the image portion.
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
|
| 197 |
+
|
| 198 |
+
    def process_img(self, x, index=0, h_offset=0, w_offset=0):
        """Patchify a (b, c, h, w) latent and build its 3-axis positional ids.

        index goes into id channel 0 (distinguishes main latent from reference
        latents); h_offset/w_offset shift the row/column ids so multiple latents
        can share one coordinate space. Returns (tokens, ids) where tokens is
        (b, h_len*w_len, c*ph*pw) and ids is (b, h_len*w_len, 3).
        """
        bs, c, h, w = x.shape
        patch_size = self.patch_size
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
        # Rounded-up token-grid dimensions after padding to a multiple of patch_size.
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)

        # Offsets are expressed in pixels; convert to patch units with the same rounding.
        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
        w_offset = ((w_offset + (patch_size // 2)) // patch_size)

        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        # NOTE: channel 1 is still all-zero here, so this just writes `index`
        # into channel 0; channels 1/2 get row/column coordinates below.
        img_ids[:, :, 0] = img_ids[:, :, 1] + index
        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
| 215 |
+
|
| 216 |
+
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
| 217 |
+
bs, c, h_orig, w_orig = x.shape
|
| 218 |
+
patch_size = self.patch_size
|
| 219 |
+
|
| 220 |
+
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
| 221 |
+
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
| 222 |
+
img, img_ids = self.process_img(x)
|
| 223 |
+
img_tokens = img.shape[1]
|
| 224 |
+
if ref_latents is not None:
|
| 225 |
+
h = 0
|
| 226 |
+
w = 0
|
| 227 |
+
for ref in ref_latents:
|
| 228 |
+
h_offset = 0
|
| 229 |
+
w_offset = 0
|
| 230 |
+
if ref.shape[-2] + h > ref.shape[-1] + w:
|
| 231 |
+
w_offset = w
|
| 232 |
+
else:
|
| 233 |
+
h_offset = h
|
| 234 |
+
|
| 235 |
+
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
|
| 236 |
+
img = torch.cat([img, kontext], dim=1)
|
| 237 |
+
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
| 238 |
+
h = max(h, ref.shape[-2] + h_offset)
|
| 239 |
+
w = max(w, ref.shape[-1] + w_offset)
|
| 240 |
+
|
| 241 |
+
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
| 242 |
+
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
| 243 |
+
out = out[:, :img_tokens]
|
| 244 |
+
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
ComfyUI/comfy/ldm/flux/redux.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import comfy.ops
|
| 3 |
+
|
| 4 |
+
ops = comfy.ops.manual_cast
|
| 5 |
+
|
| 6 |
+
class ReduxImageEncoder(torch.nn.Module):
    """Projects SigLIP image embeddings into the text-conditioning feature space.

    A two-layer MLP with a SiLU nonlinearity in between:
    redux_dim -> 3 * txt_in_features -> txt_in_features.
    """

    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        self.redux_dim = redux_dim
        self.device = device
        self.dtype = dtype

        # Up-projection and down-projection; activation is applied in forward().
        self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
        self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        expanded = self.redux_up(sigclip_embeds)
        activated = torch.nn.functional.silu(expanded)
        return self.redux_down(activated)
|
ComfyUI/comfy/ldm/hidream/model.py
ADDED
|
@@ -0,0 +1,802 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import einops
|
| 6 |
+
from einops import repeat
|
| 7 |
+
|
| 8 |
+
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from comfy.ldm.flux.math import apply_rope, rope
|
| 12 |
+
from comfy.ldm.flux.layers import LastLayer
|
| 13 |
+
|
| 14 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 15 |
+
import comfy.model_management
|
| 16 |
+
import comfy.ldm.common_dit
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
| 20 |
+
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND(nn.Module):
    """Multi-axis rotary position embedder.

    Each of the trailing id axes gets its own rope embedding of width
    axes_dim[i]; the per-axis embeddings are concatenated along dim=-3.
    """

    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        per_axis = [
            rope(ids[..., axis], self.axes_dim[axis], self.theta)
            for axis in range(n_axes)
        ]
        emb = torch.cat(per_axis, dim=-3)
        return emb.unsqueeze(2)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class PatchEmbed(nn.Module):
    """Linear patch embedding for pre-flattened patches.

    Maps tokens of width in_channels * patch_size**2 to out_channels via a
    single linear layer; the input is expected to already be patchified into
    a (batch, tokens, features) sequence.
    """

    def __init__(
        self,
        patch_size=2,
        in_channels=4,
        out_channels=1024,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        patch_features = in_channels * patch_size * patch_size
        self.proj = operations.Linear(patch_features, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, latent):
        return self.proj(latent)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class PooledEmbed(nn.Module):
    """Embeds a pooled text vector into the model's hidden size via an MLP."""

    def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        # Reuses TimestepEmbedding as a generic two-layer projection MLP.
        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)

    def forward(self, pooled_embed):
        return self.pooled_embedder(pooled_embed)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TimestepEmbed(nn.Module):
    """Sinusoidal timestep features followed by an MLP projection to hidden_size."""

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)

    def forward(self, timesteps, wdtype):
        # Sinusoidal features are computed in float; cast to the model's
        # working dtype before the MLP.
        sinusoid = self.time_proj(timesteps).to(dtype=wdtype)
        return self.timestep_embedder(sinusoid)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    """Run optimized attention on per-head tensors.

    Inputs are 4-D with heads in dim 2 (as produced by the attention
    processors in this file); each is flattened to (batch, tokens,
    heads * head_dim) and the head count is passed through.
    """
    heads = query.shape[2]
    q = query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2])
    k = key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2])
    v = value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2])
    return optimized_attention(q, k, v, heads)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class HiDreamAttnProcessor_flashattn:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __call__(
        self,
        attn,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # attn is a HiDreamAttention instance; its to_q/to_k/... projections and
        # `single`/`heads` attributes drive this processor.
        dtype = image_tokens.dtype
        batch_size = image_tokens.shape[0]

        # Image-stream q/k are RMS-normalized after projection; v is not.
        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
        value_i = attn.to_v(image_tokens)

        inner_dim = key_i.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape to (batch, tokens, heads, head_dim).
        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
        if image_tokens_masks is not None:
            # Zero out keys of masked (padding) image tokens.
            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)

        if not attn.single:
            # Dual-stream mode: project the text stream with its own weights,
            # then attend over the concatenated image+text sequence.
            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
            value_t = attn.to_v_t(text_tokens)

            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

            num_image_tokens = query_i.shape[1]
            num_text_tokens = query_t.shape[1]
            query = torch.cat([query_i, query_t], dim=1)
            key = torch.cat([key_i, key_t], dim=1)
            value = torch.cat([value_i, value_t], dim=1)
        else:
            query = query_i
            key = key_i
            value = value_i

        # rope covers rope.shape[-3]*2 feature dims. When head_dim matches,
        # rotate everything; otherwise rotate only the first half of the
        # feature dim and pass the rest through unchanged.
        if query.shape[-1] == rope.shape[-3] * 2:
            query, key = apply_rope(query, key, rope)
        else:
            query_1, query_2 = query.chunk(2, dim=-1)
            key_1, key_2 = key.chunk(2, dim=-1)
            query_1, key_1 = apply_rope(query_1, key_1, rope)
            query = torch.cat([query_1, query_2], dim=-1)
            key = torch.cat([key_1, key_2], dim=-1)

        hidden_states = attention(query, key, value)

        if not attn.single:
            # Split the joint sequence back and apply per-stream output projections.
            hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
            hidden_states_i = attn.to_out(hidden_states_i)
            hidden_states_t = attn.to_out_t(hidden_states_t)
            return hidden_states_i, hidden_states_t
        else:
            hidden_states = attn.to_out(hidden_states)
            return hidden_states
|
| 145 |
+
|
| 146 |
+
class HiDreamAttention(nn.Module):
    # Joint image/text attention module. Projections for the image stream are
    # always created; when `single` is False a parallel set of `*_t`
    # projections is added for the text stream. The attention computation
    # itself is delegated to `self.processor`
    # (see HiDreamAttnProcessor_flashattn).
    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        scale_qk: bool = True,
        eps: float = 1e-5,
        processor = None,
        out_dim: int = None,
        single: bool = False,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        # out_dim, when given, overrides both the projection width and the
        # effective head count.
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.out_dim = out_dim if out_dim is not None else query_dim

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.single = single

        linear_cls = operations.Linear
        self.linear_cls = linear_cls
        # Image-stream projections and q/k RMS norms.
        self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
        self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
        self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
        self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
        self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
        self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)

        if not single:
            # Text-stream projections (dual-stream blocks only).
            self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
            self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
            self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
            self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
            self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
            self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)

        self.processor = processor

    def forward(
        self,
        norm_image_tokens: torch.FloatTensor,
        image_tokens_masks: torch.FloatTensor = None,
        norm_text_tokens: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.Tensor:
        # Delegates entirely to the processor; returns a single tensor in
        # single-stream mode, or an (image, text) pair in dual-stream mode.
        return self.processor(
            self,
            image_tokens = norm_image_tokens,
            image_tokens_masks = image_tokens_masks,
            text_tokens = norm_text_tokens,
            rope = rope,
        )
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class FeedForwardSwiGLU(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1(x)) * w3(x)).

    The nominal hidden_dim is scaled by 2/3 (standard SwiGLU sizing),
    optionally multiplied by ffn_dim_multiplier, then rounded up to the
    nearest multiple of `multiple_of`.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round up to the nearest multiple of `multiple_of`.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
        self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
        self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        gate = torch.nn.functional.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
| 238 |
+
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
    # Top-k softmax gate that routes tokens to experts. The auxiliary-loss
    # machinery from the original is stripped; aux_loss is always None here.
    def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
        super().__init__()
        self.top_k = num_activated_experts
        self.n_routed_experts = num_routed_experts

        self.scoring_func = 'softmax'
        self.alpha = aux_loss_alpha
        self.seq_aux = False

        # topk selection algorithm
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        # Left uninitialized on purpose; presumably populated from a
        # checkpoint (the original kaiming init is disabled below).
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Intentionally a no-op; see the note on self.weight above.
        pass

    def forward(self, hidden_states):
        """Return (topk_idx, topk_weight, aux_loss) for the flattened tokens.

        topk_idx/topk_weight have shape (bsz*seq_len, top_k). aux_loss is
        always None in this inference-only variant.
        """
        bsz, seq_len, h = hidden_states.shape

        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        # cast_to keeps the gate weight on the tokens' device/dtype (relevant
        # for offloaded/low-precision models).
        logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        ### norm gate to sum 1
        # Disabled by default (norm_topk_prob is False in __init__).
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        aux_loss = None
        return topk_idx, topk_weight, aux_loss
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
| 283 |
+
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
    # Mixture-of-experts SwiGLU FFN: a shared expert always runs on every
    # token, plus top-k routed experts selected by MoEGate.
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_routed_experts: int,
        num_activated_experts: int,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        # Shared expert uses half the hidden width of the routed experts.
        self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
        self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
        self.gate = MoEGate(
            embed_dim = dim,
            num_routed_experts = num_routed_experts,
            num_activated_experts = num_activated_experts,
            dtype=dtype, device=device, operations=operations
        )
        self.num_activated_experts = num_activated_experts

    def forward(self, x):
        wtype = x.dtype
        identity = x
        orig_shape = x.shape
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if True: # self.training: # TODO: check which branch performs faster
            # "Training-style" dispatch: duplicate every token once per
            # selected expert, run each expert on its boolean-masked slice,
            # then weight by the gate scores and sum back per token.
            x = x.repeat_interleave(self.num_activated_experts, dim=0)
            y = torch.empty_like(x, dtype=wtype)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape).to(dtype=wtype)
            #y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            # Currently unreachable (the `if True` above); kept as the
            # batched inference alternative.
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        # The shared expert's contribution is added unconditionally.
        y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        # Sort token/expert assignments by expert id so each expert processes
        # one contiguous slice, then scatter-add the weighted outputs back to
        # their source token rows.
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # Each token appears num_activated_experts times in the flat index.
        token_idxs = idxs // self.num_activated_experts
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i-1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)
            # In-place scale by the gate weight for this (token, expert) pair.
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # for fp16 and other dtype
            expert_cache = expert_cache.to(expert_out.dtype)
            expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
        return expert_cache
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class TextProjection(nn.Module):
    """Bias-free linear projection of caption embeddings to hidden_size."""

    def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)

    def forward(self, caption):
        return self.linear(caption)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class BlockType:
    # Enum-like tags distinguishing dual-stream (image+text) transformer
    # blocks from single-stream ones.
    TransformerBlock = 1
    SingleTransformerBlock = 2
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class HiDreamImageSingleTransformerBlock(nn.Module):
    # Single-stream DiT block: adaLN-modulated self-attention over image
    # tokens, followed by an adaLN-modulated (optionally MoE) SwiGLU FFN.
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        # Produces the 6 adaLN parameters: shift/scale/gate for attention and FFN.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
        )

        # 1. Attention
        self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = True,
            dtype=dtype, device=device, operations=operations
        )

        # 3. Feed-forward
        self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
                dtype=dtype, device=device, operations=operations
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,  # unused in single-stream blocks
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        wtype = image_tokens.dtype
        # [:, None] broadcasts the per-batch modulation over the token axis.
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        attn_output_i = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            rope = rope,
        )
        # Gated residual connection.
        image_tokens = gate_msa_i * attn_output_i + image_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
        image_tokens = ff_output_i + image_tokens
        return image_tokens
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
class HiDreamImageTransformerBlock(nn.Module):
    """Dual-stream (image + text) transformer block with adaLN modulation.

    Both streams share one joint attention (``HiDreamAttention`` with
    ``single=False``) and each stream has its own LayerNorms, modulation
    parameters and feed-forward (the image stream optionally uses a
    mixture-of-experts SwiGLU).
    """
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        # Projects the conditioning vector into 12 modulation tensors:
        # shift/scale/gate for attention and for the MLP, for both streams.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
        )
        # nn.init.zeros_(self.adaLN_modulation[1].weight)
        # nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention (pre-norms for each stream; affine params come from adaLN)
        self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = False,
            dtype=dtype, device=device, operations=operations
        )

        # 3. Feed-forward
        self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
                dtype=dtype, device=device, operations=operations
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
        # FIX: pass dtype/device like every other LayerNorm in this module.
        # The original omitted them, so norm3_t was created with the default
        # dtype on the default device, inconsistent with its siblings.
        self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
        self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """Run the dual-stream block.

        Returns ``(image_tokens, text_tokens)``, both updated through joint
        attention and their per-stream feed-forwards.
        """
        wtype = image_tokens.dtype
        # 12 modulation tensors: 6 for the image stream, 6 for the text stream;
        # [:, None] adds a sequence axis so they broadcast over tokens.
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
        shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)

        # 1. MM-Attention: modulate each stream, then attend jointly.
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t

        attn_output_i, attn_output_t = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            norm_text_tokens,
            rope = rope,
        )

        # Gated residuals per stream.
        image_tokens = gate_msa_i * attn_output_i + image_tokens
        text_tokens = gate_msa_t * attn_output_t + text_tokens

        # 2. Feed-forward, again per stream with gated residuals.
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t

        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
        image_tokens = ff_output_i + image_tokens
        text_tokens = ff_output_t + text_tokens
        return image_tokens, text_tokens
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class HiDreamImageBlock(nn.Module):
    """Thin wrapper selecting between the dual-stream and single-stream
    HiDream block implementations and delegating all calls to it."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        block_type: BlockType = BlockType.TransformerBlock,
        dtype=None, device=None, operations=None
    ):
        super().__init__()
        # Resolve the concrete block class for the requested type.
        wrapped_cls = {
            BlockType.TransformerBlock: HiDreamImageTransformerBlock,
            BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
        }[block_type]
        self.block = wrapped_cls(
            dim,
            num_attention_heads,
            attention_head_dim,
            num_routed_experts,
            num_activated_experts,
            dtype=dtype, device=device, operations=operations,
        )

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """Forward everything to the wrapped block unchanged."""
        return self.block(
            image_tokens,
            image_tokens_masks=image_tokens_masks,
            text_tokens=text_tokens,
            adaln_input=adaln_input,
            rope=rope,
        )
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class HiDreamImageTransformer2DModel(nn.Module):
    """HiDream-Image diffusion transformer.

    Patchifies image latents, embeds the timestep and pooled text embedding
    into a single adaLN conditioning vector, then runs a stack of dual-stream
    (image+text) blocks followed by a stack of single-stream blocks. Text
    conditioning comes from per-layer llama hidden states (one slice per
    block, selected via ``llama_layers``) plus a T5 embedding, each projected
    to the model width by ``caption_projection``.
    """
    def __init__(
        self,
        patch_size: Optional[int] = None,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 16,
        num_single_layers: int = 32,
        attention_head_dim: int = 128,
        num_attention_heads: int = 20,
        caption_channels: List[int] = None,
        text_emb_dim: int = 2048,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        axes_dims_rope: Tuple[int, int] = (32, 32),
        max_resolution: Tuple[int, int] = (128, 128),
        llama_layers: List[int] = None,
        image_model=None,
        dtype=None, device=None, operations=None
    ):
        # NOTE(review): these plain attributes are assigned before
        # super().__init__() — confirm this ordering is intentional.
        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.num_layers = num_layers
        self.num_single_layers = num_single_layers

        self.gradient_checkpointing = False

        super().__init__()
        self.dtype = dtype
        self.out_channels = out_channels or in_channels
        self.inner_dim = self.num_attention_heads * self.attention_head_dim
        # Indices of the llama hidden-state layers consumed per block.
        self.llama_layers = llama_layers

        self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
        self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
        self.x_embedder = PatchEmbed(
            patch_size = patch_size,
            in_channels = in_channels,
            out_channels = self.inner_dim,
            dtype=dtype, device=device, operations=operations
        )
        # Rotary positional embedding over (extra, H, W) axes.
        self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)

        self.double_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.num_attention_heads,
                    attention_head_dim = self.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.TransformerBlock,
                    dtype=dtype, device=device, operations=operations
                )
                for i in range(self.num_layers)
            ]
        )

        self.single_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.num_attention_heads,
                    attention_head_dim = self.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.SingleTransformerBlock,
                    dtype=dtype, device=device, operations=operations
                )
                for i in range(self.num_single_layers)
            ]
        )

        self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

        # One projection per block for the llama states (caption_channels[1])
        # plus a final one for the T5 states (caption_channels[0]).
        caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
        caption_projection = []
        for caption_channel in caption_channels:
            caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
        self.caption_projection = nn.ModuleList(caption_projection)
        # Maximum number of image patches supported.
        self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

    def expand_timesteps(self, timesteps, batch_size, device):
        """Coerce ``timesteps`` to a 1-D tensor of length ``batch_size`` on
        ``device`` (scalar / python number inputs are broadcast)."""
        if not torch.is_tensor(timesteps):
            is_mps = device.type == "mps"
            # MPS lacks float64/int64 here, so narrow the dtype accordingly.
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(batch_size)
        return timesteps

    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
        """Invert :meth:`patchify`: fold each sample's first ``pH*pW`` tokens
        back into a (C, H, W) image using its recorded patch-grid size."""
        x_arr = []
        for i, img_size in enumerate(img_sizes):
            pH, pW = img_size
            x_arr.append(
                einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
                    p1=self.patch_size, p2=self.patch_size)
            )
        x = torch.cat(x_arr, dim=0)
        return x

    def patchify(self, x, max_seq, img_sizes=None):
        """Convert image latents into a token sequence.

        Three cases:
        - ``img_sizes`` given: ``x`` is already (B, C, S, p); build a validity
          mask over ``max_seq`` tokens per sample.
        - ``x`` is a tensor without ``img_sizes``: split (B, C, H, W) into
          patch tokens; all samples share one grid, so no mask is needed.
        - otherwise: unsupported.

        Returns ``(tokens, mask_or_None, img_sizes)``.
        """
        pz2 = self.patch_size * self.patch_size
        if isinstance(x, torch.Tensor):
            B = x.shape[0]
            device = x.device
            dtype = x.dtype
        else:
            B = len(x)
            device = x[0].device
            dtype = x[0].dtype
        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)

        if img_sizes is not None:
            # Mark the valid token span of each (variable-size) sample.
            for i, img_size in enumerate(img_sizes):
                x_masks[i, 0:img_size[0] * img_size[1]] = 1
            x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
        elif isinstance(x, torch.Tensor):
            pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
            x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
            img_sizes = [[pH, pW]] * B
            x_masks = None
        else:
            raise NotImplementedError
        return x, x_masks, img_sizes

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        encoder_hidden_states_llama3=None,
        image_cond=None,
        control = None,
        transformer_options = {},
    ) -> torch.Tensor:
        """Denoise ``x`` at timestep ``t``.

        ``y`` is the pooled text embedding, ``context`` the T5 hidden states,
        ``encoder_hidden_states_llama3`` the stacked llama layer outputs.
        ``control`` is accepted but unused in this body.
        NOTE(review): ``transformer_options={}`` is a mutable default argument;
        it is only read here, but confirm callers never rely on mutation.
        """
        bs, c, h, w = x.shape
        if image_cond is not None:
            # Concatenate the conditioning image along width.
            x = torch.cat([x, image_cond], dim=-1)
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        timesteps = t
        pooled_embeds = y
        T5_encoder_hidden_states = context

        img_sizes = None

        # spatial forward
        batch_size = hidden_states.shape[0]
        hidden_states_type = hidden_states.dtype

        # 0. time: combine timestep and pooled-text embeddings into the
        # single adaLN conditioning vector used by every block.
        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
        timesteps = self.t_embedder(timesteps, hidden_states_type)
        p_embedder = self.p_embedder(pooled_embeds)
        adaln_input = timesteps + p_embedder

        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
        if image_tokens_masks is None:
            # Uniform grid: build per-token (0, row, col) position ids for RoPE.
            pH, pW = img_sizes[0]
            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
        hidden_states = self.x_embedder(hidden_states)

        # T5_encoder_hidden_states = encoder_hidden_states[0]
        # Select one llama layer slice per block.
        encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]

        if self.caption_projection is not None:
            # Project each llama slice (and finally T5) to the model width.
            new_encoder_hidden_states = []
            for i, enc_hidden_state in enumerate(encoder_hidden_states):
                enc_hidden_state = self.caption_projection[i](enc_hidden_state)
                enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                new_encoder_hidden_states.append(enc_hidden_state)
            encoder_hidden_states = new_encoder_hidden_states
            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
            encoder_hidden_states.append(T5_encoder_hidden_states)

        # Text tokens get all-zero position ids (no spatial rotation).
        txt_ids = torch.zeros(
            batch_size,
            encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
            3,
            device=img_ids.device, dtype=img_ids.dtype
        )
        ids = torch.cat((img_ids, txt_ids), dim=1)
        rope = self.pe_embedder(ids)

        # 2. Blocks
        block_id = 0
        # Shared prefix of text tokens: T5 (+ last llama slice); each block
        # appends its own llama slice on top.
        initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
        initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
        for bid, block in enumerate(self.double_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            hidden_states, initial_encoder_hidden_states = block(
                image_tokens = hidden_states,
                image_tokens_masks = image_tokens_masks,
                text_tokens = cur_encoder_hidden_states,
                adaln_input = adaln_input,
                rope = rope,
            )
            # Keep only the shared-prefix part for the next block.
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        # Single-stream phase: fold the text prefix into the image sequence.
        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if image_tokens_masks is not None:
            # Text tokens are always valid; extend the mask accordingly.
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
            )
            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)

        for bid, block in enumerate(self.single_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            # Append this block's llama slice, run, then trim it back off.
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            hidden_states = block(
                image_tokens=hidden_states,
                image_tokens_masks=image_tokens_masks,
                text_tokens=None,
                adaln_input=adaln_input,
                rope=rope,
            )
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        # Drop the text tokens, project out, and unpatchify back to an image.
        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
        output = self.final_layer(hidden_states, adaln_input)
        output = self.unpatchify(output, img_sizes)
        # Output is negated and cropped back to the (pre-padding) input size.
        return -output[:, :, :h, :w]
|
ComfyUI/comfy/ldm/hunyuan3d/model.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from comfy.ldm.flux.layers import (
|
| 4 |
+
DoubleStreamBlock,
|
| 5 |
+
LastLayer,
|
| 6 |
+
MLPEmbedder,
|
| 7 |
+
SingleStreamBlock,
|
| 8 |
+
timestep_embedding,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Hunyuan3Dv2(nn.Module):
    """Hunyuan3D v2 shape-generation transformer built from Flux-style
    double- and single-stream blocks, conditioned on a timestep, optional
    guidance embedding, and a context sequence projected by ``cond_in``."""

    def __init__(
        self,
        in_channels=64,
        context_in_dim=1536,
        hidden_size=1024,
        mlp_ratio=4.0,
        num_heads=16,
        depth=16,
        depth_single_blocks=32,
        qkv_bias=True,
        guidance_embed=False,
        image_model=None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.dtype = dtype

        if hidden_size % num_heads != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
            )

        self.max_period = 1000  # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
        self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
        # Guidance embedder only exists for guidance-distilled checkpoints.
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
        )
        self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(depth)
            ]
        )
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(depth_single_blocks)
            ]
        )
        self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)

    def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
        """Run the model.

        ``x`` is moved from channel-last to channel-second layout before the
        linear input projection and moved back at the end; the timestep is
        inverted (``1.0 - timestep``) and the final output negated — the model
        was trained with the opposite sign/time convention.
        Supports per-block replacement via
        ``transformer_options["patches_replace"]["dit"]``.
        """
        x = x.movedim(-1, -2)
        timestep = 1.0 - timestep
        txt = context
        img = self.latent_in(x)

        vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
        if self.guidance_in is not None:
            if guidance is not None:
                vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))

        txt = self.cond_in(txt)
        # No positional encoding or attention mask is used by this model.
        pe = None
        attn_mask = None

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if ("double_block", i) in blocks_replace:
                # Wrap the block so a patch can call the original with
                # (possibly modified) arguments.
                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"],
                                                   txt=args["txt"],
                                                   vec=args["vec"],
                                                   pe=args["pe"],
                                                   attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("double_block", i)]({"img": img,
                                                           "txt": txt,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 attn_mask=attn_mask)

        # Single-stream phase operates on the concatenated [txt, img] sequence.
        img = torch.cat((txt, img), 1)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"],
                                       vec=args["vec"],
                                       pe=args["pe"],
                                       attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("single_block", i)]({"img": img,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

        # Drop the text tokens again before the output projection.
        img = img[:, txt.shape[1]:, ...]
        img = self.final_layer(img, vec)
        return img.movedim(-2, -1) * (-1.0)
|
ComfyUI/comfy/ldm/hunyuan3d/vae.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
|
| 2 |
+
# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI.
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from typing import Union, Tuple, List, Callable, Optional
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from einops import repeat, rearrange
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
import comfy.ops
|
| 17 |
+
ops = comfy.ops.disable_weight_init
|
| 18 |
+
|
| 19 |
+
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Build a dense regular grid of query points spanning a bounding box.

    Each axis is sampled at ``octree_resolution + 1`` evenly spaced float32
    positions between the corresponding ``bbox_min``/``bbox_max`` components.

    Returns:
        xyz: float32 array of shape (n, n, n, 3) with n = octree_resolution+1,
            holding the coordinates of every grid point.
        grid_size: [n, n, n].
        length: per-axis extent, ``bbox_max - bbox_min``.
    """
    extent = bbox_max - bbox_min
    points_per_axis = int(octree_resolution) + 1
    axes = [
        np.linspace(bbox_min[axis], bbox_max[axis], points_per_axis, dtype=np.float32)
        for axis in range(3)
    ]
    mesh = np.meshgrid(*axes, indexing=indexing)
    xyz = np.stack(mesh, axis=-1)
    grid_size = [points_per_axis, points_per_axis, points_per_axis]
    return xyz, grid_size, extent
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class VanillaVolumeDecoder:
    """Decodes shape latents into a dense logit volume by querying the
    geometry decoder at every point of a regular 3-D grid."""

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Evaluate ``geo_decoder`` over a dense grid.

        Args:
            latents: shape latents of shape (B, ...); only the batch size,
                device and dtype are used directly here.
            geo_decoder: callable taking ``queries`` (B, P, 3) and ``latents``
                and returning per-query logits.
            bounds: either a scalar b (expanded to [-b]*3 + [b]*3) or a
                6-element [xmin, ymin, zmin, xmax, ymax, zmax] box.
            num_chunks: number of query points evaluated per decoder call
                (bounds peak memory).
            octree_resolution: grid resolution; the grid has
                (octree_resolution + 1)^3 points.
            enable_pbar: show a tqdm progress bar.

        Returns:
            float32 tensor of shape (B, n, n, n) with n = octree_resolution+1.
        """
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        if isinstance(bounds, float):
            # Scalar bound -> symmetric box around the origin.
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        # Flatten the (n, n, n, 3) grid to a (n^3, 3) query list.
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            # Same query chunk for every batch element.
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        # Reassemble the chunked logits into the full grid.
        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()

        return grid_logits
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class FourierEmbedder(nn.Module):
    """Sin/cos Fourier feature embedding.

    Expands every scalar feature of the input into sin and cos features at
    ``num_freqs`` frequencies, optionally keeping the raw input as well. With
    ``logspace`` the frequencies are the powers of two ``2^0 .. 2^(num_freqs-1)``;
    otherwise they are linearly spaced over ``[1, 2^(num_freqs-1)]``. When
    ``include_pi`` is set, every frequency is additionally scaled by pi.

    The resulting width is ``input_dim * (2 * num_freqs + 1)`` when
    ``include_input`` is True and ``input_dim * 2 * num_freqs`` otherwise
    (see :meth:`get_dims`); it is exposed as ``self.out_dim``.
    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        """Precompute the frequency bank and the output width."""
        super().__init__()

        if logspace:
            freq_bank = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            freq_bank = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32,
            )

        if include_pi:
            freq_bank = freq_bank * torch.pi

        # Non-persistent: the bank is fully determined by constructor args,
        # so it need not be stored in checkpoints.
        self.register_buffer("frequencies", freq_bank, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the embedding width for an input with ``input_dim`` features."""
        extra = 1 if self.include_input or self.num_freqs == 0 else 0
        return input_dim * (self.num_freqs * 2 + extra)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of shape [..., dim] into [..., out_dim].

        The per-dimension frequency products are flattened so the layout is
        [d0*f0, d0*f1, ..., d1*f0, ...] within each of the sin and cos halves.
        With ``num_freqs == 0`` the input is returned unchanged.
        """
        if self.num_freqs <= 0:
            return x
        scaled = x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)
        scaled = scaled.view(*x.shape[:-1], -1)
        pieces = [scaled.sin(), scaled.cos()]
        if self.include_input:
            pieces.insert(0, x)
        return torch.cat(pieces, dim=-1)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class CrossAttentionProcessor:
    """Default attention kernel: plain scaled dot-product attention.

    The owning module is passed as ``attn`` so that custom processors can
    read its configuration; this default implementation ignores it.
    """

    def __call__(self, attn, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class DropPath(nn.Module):
    """Stochastic depth: drop whole residual-branch samples during training.

    This is the same as the DropConnect impl created for EfficientNet-style
    networks; the original name is misleading as 'Drop Connect' is a
    different form of dropout in a separate paper (see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956),
    hence 'drop path' here.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Randomly zero whole samples of ``x``; identity at eval time."""
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast across all trailing dims
        # so this works for tensors of any rank, not just 2D ConvNet inputs.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            mask.div_(keep_prob)
        return x * mask

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class MLP(nn.Module):
    """Two-layer GELU feed-forward block.

    Projects width -> width * expand_ratio -> output_width (defaults back to
    ``width``), with optional stochastic depth on the output.
    """

    def __init__(
        self, *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        hidden_width = width * expand_ratio
        final_width = width if output_width is None else output_width
        self.c_fc = ops.Linear(width, hidden_width)
        self.c_proj = ops.Linear(hidden_width, final_width)
        self.gelu = nn.GELU()
        # Identity unless stochastic depth is requested.
        self.drop_path = nn.Identity() if drop_path_rate <= 0. else DropPath(drop_path_rate)

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        return self.drop_path(self.c_proj(hidden))
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class QKVMultiheadCrossAttention(nn.Module):
    """Multi-head cross attention over pre-projected query and fused k/v.

    ``q`` is [bs, n_ctx, width]; ``kv`` packs keys and values along its last
    dimension.  The attention kernel is pluggable via ``attn_processor``.
    """

    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        head_dim = width // heads
        self.q_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        # kv fuses k and v, so each per-head channel count is width/heads/2.
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        # [b, n, h, d] -> [b, h, n, d] for the attention kernel.
        q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
        out = self.attn_processor(self, q, k, v)
        return out.transpose(1, 2).reshape(bs, n_ctx, -1)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class MultiheadCrossAttention(nn.Module):
    """Cross attention with a separate query and fused key/value projection.

    With ``kv_cache`` enabled the projected k/v of the first ``data`` tensor
    is stored in ``self.data`` and reused on every subsequent call.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = data_width if data_width is not None else width
        self.c_q = ops.Linear(width, width, bias=qkv_bias)
        self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.kv_cache = kv_cache
        self.data = None  # cached kv projection when kv_cache is enabled

    def forward(self, x, data):
        x = self.c_q(x)
        if self.kv_cache:
            if self.data is None:
                # Project kv once and reuse the result on later calls.
                self.data = self.c_kv(data)
                logging.info('Save kv cache,this should be called only once for one mesh')
            data = self.data
        else:
            data = self.c_kv(data)
        return self.c_proj(self.attention(x, data))
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class ResidualCrossAttentionBlock(nn.Module):
    """Pre-norm cross-attention block followed by an MLP, both residual.

    ln_1 normalizes the queries, ln_2 the context (``data``), and ln_3 the
    MLP input.
    """

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False
    ):
        super().__init__()

        data_width = width if data_width is None else data_width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        return x + self.mlp(self.ln_3(x))
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class QKVMultiheadAttention(nn.Module):
    """Multi-head self attention over a fused q/k/v tensor."""

    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        head_dim = width // heads
        self.q_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        # qkv fuses q, k and v, so each per-head channel count is width/heads/3.
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # [b, n, h, d] -> [b, h, n, d] for scaled_dot_product_attention.
        q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
        return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class MultiheadAttention(nn.Module):
    """Self attention with a fused qkv projection and optional drop-path."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        # Identity unless stochastic depth is requested.
        self.drop_path = nn.Identity() if drop_path_rate <= 0. else DropPath(drop_path_rate)

    def forward(self, x):
        return self.drop_path(self.c_proj(self.attention(self.c_qkv(x))))
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self attention then MLP, both residual."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class Transformer(nn.Module):
    """A sequential stack of ``layers`` ResidualAttentionBlocks."""

    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(
                width=width,
                heads=heads,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                qk_norm=qk_norm,
                drop_path_rate=drop_path_rate
            )
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        for resblock in self.resblocks:
            x = resblock(x)
        return x
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class CrossAttentionDecoder(nn.Module):
    """Decodes per-query-point logits by cross-attending to latent tokens.

    Query 3D points are Fourier-embedded and projected to ``width``; a single
    residual cross-attention block then attends to the (optionally projected)
    latent tokens before the final linear head produces ``out_channels``
    logits per query.
    """

    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            # Latents are wider by the downsample ratio; project them down.
            self.latents_proj = ops.Linear(width * downsample_ratio, width)
        # Fixed: idiomatic truthiness test instead of `== False`.
        if not self.enable_ln_post:
            # Without the post-LayerNorm, qk normalization is disabled too.
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        if self.enable_ln_post:
            self.ln_post = ops.LayerNorm(width)
        self.output_proj = ops.Linear(width, out_channels)
        self.label_type = label_type
        self.count = 0  # running total of decoded query tokens

    def forward(self, queries=None, query_embeddings=None, latents=None):
        """Return logits for ``queries`` (or precomputed ``query_embeddings``).

        Exactly one of ``queries`` / ``query_embeddings`` is expected;
        ``latents`` supplies the context tokens to attend over.
        """
        if query_embeddings is None:
            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class ShapeVAE(nn.Module):
    """Decoder-only shape VAE wrapper.

    Lifts KL latents to the transformer width, runs a self-attention stack,
    and decodes a dense logit grid via a volume decoder driven by a
    cross-attention geometry decoder.
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        # Lift KL latents (embed_dim) up to the transformer width.
        self.post_kl = ops.Linear(embed_dim, width)

        self.transformer = Transformer(
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.volume_decoder = VanillaVolumeDecoder()
        self.scale_factor = scale_factor

    def decode(self, latents, **kwargs):
        """Decode latent tokens into a dense grid of logits.

        Recognized kwargs: bounds (default 1.01), num_chunks (8000),
        octree_resolution (256), enable_pbar (True).
        """
        hidden = self.post_kl(latents.movedim(-2, -1))
        hidden = self.transformer(hidden)

        grid_logits = self.volume_decoder(
            hidden,
            self.geo_decoder,
            bounds=kwargs.get("bounds", 1.01),
            num_chunks=kwargs.get("num_chunks", 8000),
            octree_resolution=kwargs.get("octree_resolution", 256),
            enable_pbar=kwargs.get("enable_pbar", True),
        )
        return grid_logits.movedim(-2, -1)

    def encode(self, x):
        # Encoding is not implemented for this decoder-only wrapper.
        return None
|
ComfyUI/comfy/ldm/hunyuan_video/model.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Based on Flux code because of weird hunyuan video code license.
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import comfy.ldm.flux.layers
|
| 5 |
+
import comfy.ldm.modules.diffusionmodules.mmdit
|
| 6 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from einops import repeat
|
| 11 |
+
|
| 12 |
+
from torch import Tensor, nn
|
| 13 |
+
|
| 14 |
+
from comfy.ldm.flux.layers import (
|
| 15 |
+
DoubleStreamBlock,
|
| 16 |
+
EmbedND,
|
| 17 |
+
LastLayer,
|
| 18 |
+
MLPEmbedder,
|
| 19 |
+
SingleStreamBlock,
|
| 20 |
+
timestep_embedding
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import comfy.ldm.common_dit
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class HunyuanVideoParams:
    """Configuration for the HunyuanVideo transformer (Flux-style fields)."""
    # Channel count of the input latent and of the predicted output.
    in_channels: int
    out_channels: int
    # Width of the pooled conditioning vector fed to vector_in.
    vec_in_dim: int
    # Width of the text-context tokens fed to the token refiner.
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    # Number of double-stream and single-stream blocks respectively.
    depth: int
    depth_single_blocks: int
    # Per-axis rotary embedding dims; must sum to hidden_size // num_heads.
    axes_dim: list
    # Rotary embedding base frequency.
    theta: int
    patch_size: list
    qkv_bias: bool
    guidance_embed: bool
|
| 43 |
+
|
| 44 |
+
class SelfAttentionRef(nn.Module):
    """Parameter holder for a refiner self-attention layer.

    Owns only the fused q/k/v projection and the output projection; the
    attention computation itself is done by the caller (TokenRefinerBlock),
    which reads ``self.qkv`` and ``self.proj`` directly.
    """

    def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        # Fused projection producing q, k and v in one matmul.
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class TokenRefinerBlock(nn.Module):
    """One text-token refiner layer: gated self-attention plus a gated MLP.

    The conditioning vector ``c`` passes through adaLN_modulation to produce
    two per-sample gates that scale the attention and MLP residual branches.
    """

    def __init__(
        self,
        hidden_size,
        heads,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.heads = heads
        mlp_hidden_dim = hidden_size * 4

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
        self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)

        self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)

        self.mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x, c, mask):
        gate_attn, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        normed = self.norm1(x)
        qkv = self.self_attn.qkv(normed)
        # [b, s, 3*heads*d] -> three tensors of shape [b, heads, s, d].
        q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
        attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)

        x = x + self.self_attn.proj(attn) * gate_attn.unsqueeze(1)
        x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
        return x
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class IndividualTokenRefiner(nn.Module):
    """A sequential stack of TokenRefinerBlocks sharing one attention mask."""

    def __init__(
        self,
        hidden_size,
        heads,
        num_blocks,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.blocks = nn.ModuleList([
            TokenRefinerBlock(
                hidden_size=hidden_size,
                heads=heads,
                dtype=dtype,
                device=device,
                operations=operations
            )
            for _ in range(num_blocks)
        ])

    def forward(self, x, c, mask):
        attn_mask = None
        if mask is not None:
            # Expand [b, s] -> [b, 1, s, s], then symmetrize by adding the
            # transpose so the mask is identical for both token orders.
            attn_mask = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
            attn_mask = attn_mask + attn_mask.transpose(2, 3)

        for refiner_block in self.blocks:
            x = refiner_block(x, c, attn_mask)
        return x
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class TokenRefiner(nn.Module):
    """Embeds text tokens and refines them under timestep/context conditioning."""

    def __init__(
        self,
        text_dim,
        hidden_size,
        heads,
        num_blocks,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()

        self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
        self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
        self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
        self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)

    def forward(
        self,
        x,
        timesteps,
        mask,
    ):
        t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
        # Mean-pool over all tokens, ignoring the mask:
        # m = mask.float().unsqueeze(-1)
        # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
        pooled = x.sum(dim=1) / x.shape[1]

        cond = t + self.c_embedder(pooled.to(x.dtype))
        tokens = self.input_embedder(x)
        return self.individual_token_refiner(tokens, cond, mask)
|
| 162 |
+
|
| 163 |
+
class HunyuanVideo(nn.Module):
|
| 164 |
+
"""
|
| 165 |
+
Transformer model for flow matching on sequences.
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
| 169 |
+
super().__init__()
|
| 170 |
+
self.dtype = dtype
|
| 171 |
+
params = HunyuanVideoParams(**kwargs)
|
| 172 |
+
self.params = params
|
| 173 |
+
self.patch_size = params.patch_size
|
| 174 |
+
self.in_channels = params.in_channels
|
| 175 |
+
self.out_channels = params.out_channels
|
| 176 |
+
if params.hidden_size % params.num_heads != 0:
|
| 177 |
+
raise ValueError(
|
| 178 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
| 179 |
+
)
|
| 180 |
+
pe_dim = params.hidden_size // params.num_heads
|
| 181 |
+
if sum(params.axes_dim) != pe_dim:
|
| 182 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
| 183 |
+
self.hidden_size = params.hidden_size
|
| 184 |
+
self.num_heads = params.num_heads
|
| 185 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
| 186 |
+
|
| 187 |
+
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
| 188 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
| 189 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
| 190 |
+
self.guidance_in = (
|
| 191 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
|
| 195 |
+
|
| 196 |
+
self.double_blocks = nn.ModuleList(
|
| 197 |
+
[
|
| 198 |
+
DoubleStreamBlock(
|
| 199 |
+
self.hidden_size,
|
| 200 |
+
self.num_heads,
|
| 201 |
+
mlp_ratio=params.mlp_ratio,
|
| 202 |
+
qkv_bias=params.qkv_bias,
|
| 203 |
+
flipped_img_txt=True,
|
| 204 |
+
dtype=dtype, device=device, operations=operations
|
| 205 |
+
)
|
| 206 |
+
for _ in range(params.depth)
|
| 207 |
+
]
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
self.single_blocks = nn.ModuleList(
|
| 211 |
+
[
|
| 212 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
| 213 |
+
for _ in range(params.depth_single_blocks)
|
| 214 |
+
]
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if final_layer:
|
| 218 |
+
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
| 219 |
+
|
| 220 |
+
def forward_orig(
|
| 221 |
+
self,
|
| 222 |
+
img: Tensor,
|
| 223 |
+
img_ids: Tensor,
|
| 224 |
+
txt: Tensor,
|
| 225 |
+
txt_ids: Tensor,
|
| 226 |
+
txt_mask: Tensor,
|
| 227 |
+
timesteps: Tensor,
|
| 228 |
+
y: Tensor,
|
| 229 |
+
guidance: Tensor = None,
|
| 230 |
+
guiding_frame_index=None,
|
| 231 |
+
ref_latent=None,
|
| 232 |
+
control=None,
|
| 233 |
+
transformer_options={},
|
| 234 |
+
) -> Tensor:
|
| 235 |
+
patches_replace = transformer_options.get("patches_replace", {})
|
| 236 |
+
|
| 237 |
+
initial_shape = list(img.shape)
|
| 238 |
+
# running on sequences img
|
| 239 |
+
img = self.img_in(img)
|
| 240 |
+
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
| 241 |
+
|
| 242 |
+
if ref_latent is not None:
|
| 243 |
+
ref_latent_ids = self.img_ids(ref_latent)
|
| 244 |
+
ref_latent = self.img_in(ref_latent)
|
| 245 |
+
img = torch.cat([ref_latent, img], dim=-2)
|
| 246 |
+
ref_latent_ids[..., 0] = -1
|
| 247 |
+
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
|
| 248 |
+
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
|
| 249 |
+
|
| 250 |
+
if guiding_frame_index is not None:
|
| 251 |
+
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
|
| 252 |
+
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
|
| 253 |
+
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
|
| 254 |
+
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
|
| 255 |
+
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
|
| 256 |
+
modulation_dims_txt = [(0, None, 1)]
|
| 257 |
+
else:
|
| 258 |
+
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
| 259 |
+
modulation_dims = None
|
| 260 |
+
modulation_dims_txt = None
|
| 261 |
+
|
| 262 |
+
if self.params.guidance_embed:
|
| 263 |
+
if guidance is not None:
|
| 264 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
| 265 |
+
|
| 266 |
+
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
| 267 |
+
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
| 268 |
+
|
| 269 |
+
txt = self.txt_in(txt, timesteps, txt_mask)
|
| 270 |
+
|
| 271 |
+
ids = torch.cat((img_ids, txt_ids), dim=1)
|
| 272 |
+
pe = self.pe_embedder(ids)
|
| 273 |
+
|
| 274 |
+
img_len = img.shape[1]
|
| 275 |
+
if txt_mask is not None:
|
| 276 |
+
attn_mask_len = img_len + txt.shape[1]
|
| 277 |
+
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
| 278 |
+
attn_mask[:, 0, img_len:] = txt_mask
|
| 279 |
+
else:
|
| 280 |
+
attn_mask = None
|
| 281 |
+
|
| 282 |
+
blocks_replace = patches_replace.get("dit", {})
|
| 283 |
+
for i, block in enumerate(self.double_blocks):
|
| 284 |
+
if ("double_block", i) in blocks_replace:
|
| 285 |
+
def block_wrap(args):
|
| 286 |
+
out = {}
|
| 287 |
+
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
|
| 288 |
+
return out
|
| 289 |
+
|
| 290 |
+
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
|
| 291 |
+
txt = out["txt"]
|
| 292 |
+
img = out["img"]
|
| 293 |
+
else:
|
| 294 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
|
| 295 |
+
|
| 296 |
+
if control is not None: # Controlnet
|
| 297 |
+
control_i = control.get("input")
|
| 298 |
+
if i < len(control_i):
|
| 299 |
+
add = control_i[i]
|
| 300 |
+
if add is not None:
|
| 301 |
+
img += add
|
| 302 |
+
|
| 303 |
+
img = torch.cat((img, txt), 1)
|
| 304 |
+
|
| 305 |
+
for i, block in enumerate(self.single_blocks):
|
| 306 |
+
if ("single_block", i) in blocks_replace:
|
| 307 |
+
def block_wrap(args):
|
| 308 |
+
out = {}
|
| 309 |
+
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
|
| 310 |
+
return out
|
| 311 |
+
|
| 312 |
+
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
|
| 313 |
+
img = out["img"]
|
| 314 |
+
else:
|
| 315 |
+
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
|
| 316 |
+
|
| 317 |
+
if control is not None: # Controlnet
|
| 318 |
+
control_o = control.get("output")
|
| 319 |
+
if i < len(control_o):
|
| 320 |
+
add = control_o[i]
|
| 321 |
+
if add is not None:
|
| 322 |
+
img[:, : img_len] += add
|
| 323 |
+
|
| 324 |
+
img = img[:, : img_len]
|
| 325 |
+
if ref_latent is not None:
|
| 326 |
+
img = img[:, ref_latent.shape[1]:]
|
| 327 |
+
|
| 328 |
+
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
|
| 329 |
+
|
| 330 |
+
shape = initial_shape[-3:]
|
| 331 |
+
for i in range(len(shape)):
|
| 332 |
+
shape[i] = shape[i] // self.patch_size[i]
|
| 333 |
+
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
| 334 |
+
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
| 335 |
+
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
|
| 336 |
+
return img
|
| 337 |
+
|
| 338 |
+
def img_ids(self, x):
    """Build flattened 3D (t, h, w) position ids for every latent patch.

    Args:
        x: latent tensor of shape (bs, c, t, h, w).

    Returns:
        Tensor of shape (bs, t_len * h_len * w_len, 3); each row is the
        (t, h, w) grid coordinate of one patch, identical for every batch item.
    """
    bs = x.shape[0]
    p = self.patch_size
    # Patch counts per axis; the +p//2 bias rounds the division (exact ceil
    # only for patch sizes 1 and 2) — matches the embedder's padding scheme.
    t_len, h_len, w_len = (
        (dim + (ps // 2)) // ps for dim, ps in zip(x.shape[-3:], p)
    )
    coords = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
    coords[..., 0] += torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
    coords[..., 1] += torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
    coords[..., 2] += torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
    # Flatten the grid to one id per patch and replicate over the batch
    # (same result as einops repeat "t h w c -> b (t h w) c").
    return coords.reshape(1, -1, 3).repeat(bs, 1, 1)
|
| 349 |
+
|
| 350 |
+
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
    """Public entry point: derive positional ids, then run the transformer.

    Args:
        x: latent video tensor of shape (bs, c, t, h, w).
        timestep: diffusion timestep(s) forwarded to forward_orig.
        context: text conditioning tokens, (bs, txt_len, dim).
        y: pooled conditioning vector forwarded to forward_orig.
        control: optional controlnet residuals — presumably a dict with
            "input"/"output" lists; forwarded untouched (confirm in forward_orig).
    """
    batch = x.shape[0]
    image_ids = self.img_ids(x)
    # Text tokens carry all-zero (t, h, w) positions in the shared RoPE grid.
    text_ids = torch.zeros((batch, context.shape[1], 3), device=x.device, dtype=x.dtype)
    return self.forward_orig(x, image_ids, context, text_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
|
ComfyUI/comfy/ldm/hydit/attn_layers.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Tuple, Union, Optional
|
| 4 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape frequency tensor(s) so they broadcast against `x`.

    The sequence and channel dimensions of `freqs_cis` are kept; every other
    dimension of the target shape is collapsed to 1. With `head_first` the
    sequence dim is `x.shape[-2]`, otherwise it is `x.shape[1]`.

    Args:
        freqs_cis: either a (cos, sin) pair of real tensors or a single
            complex-valued tensor of shape (seq, channels).
        x: target tensor whose shape determines the broadcast layout.
        head_first: head dimension first (except batch dim) or not.

    Returns:
        The reshaped tensor, or a (cos, sin) pair when a tuple was given.

    Raises:
        AssertionError: if the frequency shape does not match `x`, or `x`
            has fewer than two dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    def _broadcast_shape(ref_shape):
        # Validate against x and build the singleton-padded view shape.
        if head_first:
            assert ref_shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {ref_shape} does not match x shape {x.shape}'
            keep = (ndim - 2, ndim - 1)
        else:
            assert ref_shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {ref_shape} does not match x shape {x.shape}'
            keep = (1, ndim - 1)
        return [d if i in keep else 1 for i, d in enumerate(x.shape)]

    if isinstance(freqs_cis, tuple):
        # (cos, sin) pair in real space: reshape both halves identically.
        target = _broadcast_shape(freqs_cis[0].shape)
        return freqs_cis[0].view(*target), freqs_cis[1].view(*target)
    # Single tensor of complex values.
    target = _broadcast_shape(freqs_cis.shape)
    return freqs_cis.view(*target)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def rotate_half(x):
    """Rotate adjacent channel pairs by 90 degrees: (a, b) -> (-b, a).

    NOTE(review): the trailing flatten(3) implies a 4-D input such as
    [B, S, H, D] — confirm callers never pass other ranks.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    real, imag = pairs[..., 0], pairs[..., 1]
    return torch.stack((-imag, real), dim=-1).flatten(3)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: Optional[torch.Tensor],
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to query (and optionally key) tensors.

    Two frequency formats are supported: a real-valued (cos, sin) pair, in
    which case the rotation is `t * cos + rotate_half(t) * sin`; or a single
    complex tensor, in which case the inputs are viewed as complex numbers
    and multiplied elementwise.

    Args:
        xq: query tensor, e.g. [B, S, H, D].
        xk: key tensor of the same layout, or None to rotate queries only.
        freqs_cis: precomputed frequencies; tuple of real tensors or one
            complex tensor.
        head_first: head dimension first (except batch dim) or not.

    Returns:
        (rotated_xq, rotated_xk) — the second element is None when xk is None.
    """
    if isinstance(freqs_cis, tuple):
        # Real-space rotation with broadcast cos/sin tables.
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]

        def _rotate(t):
            return t * cos + rotate_half(t) * sin

        xq_out = _rotate(xq)
        xk_out = _rotate(xk) if xk is not None else None
    else:
        # Complex-space rotation: pack the last dim into complex pairs.
        xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs = reshape_for_broadcast(freqs_cis, xq_c, head_first).to(xq.device)  # [S, D//2] -> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_c * freqs).flatten(3).type_as(xq)
        xk_out = None
        if xk is not None:
            xk_c = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_c * freqs).flatten(3).type_as(xk)

    return xq_out, xk_out
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class CrossAttention(nn.Module):
    """Cross-attention between queries from `x` and keys/values from `y`,
    with optional LayerNorm on Q and K (QK normalization)."""

    def __init__(self,
                 qdim,
                 kdim,
                 num_heads,
                 qkv_bias=True,
                 qk_norm=False,
                 attn_drop=0.0,
                 proj_drop=0.0,
                 attn_precision=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.attn_precision = attn_precision
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        # K and V are projected together from the text states.
        self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)

        def _qk_norm_layer():
            # TODO: eps should be 1 / 65530 if using fp16
            if qk_norm:
                return operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
            return nn.Identity()

        self.q_norm = _qk_norm_layer()
        self.k_norm = _qk_norm_layer()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            Query source, (batch, seqlen1, hidden_dim) where
            hidden_dim = num_heads * head_dim.
        y: torch.Tensor
            Key/value source, (batch, seqlen2, hidden_dim2).
        freqs_cis_img: torch.Tensor
            RoPE frequencies for the image tokens; applied to Q only.
        """
        batch, q_len, _ = x.shape       # [b, s1, D]
        _, kv_len, _ = y.shape          # [b, s2, kdim]

        q = self.q_proj(x).view(batch, q_len, self.num_heads, self.head_dim)          # [b, s1, h, d]
        kv = self.kv_proj(y).view(batch, kv_len, 2, self.num_heads, self.head_dim)    # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)                                                       # [b, s2, h, d]
        q = self.q_norm(q)
        k = self.k_norm(k)

        # RoPE rotates only the (image-side) queries here.
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq

        # [B, L, H, D] -> [B, H, L, D] for the attention kernel.
        q = q.transpose(-2, -3).contiguous()
        k = k.transpose(-2, -3).contiguous()
        v = v.transpose(-2, -3).contiguous()

        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)

        out = self.proj_drop(self.out_proj(context))

        # Returned as a 1-tuple to match the upstream calling convention.
        return (out,)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class Attention(nn.Module):
    """Multi-head self-attention; layer names (``Wqkv``, ``out_proj``) are
    chosen to align with flash attention checkpoints."""

    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn_precision = attn_precision
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention.
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        # Fused QKV projection (qkv renamed to Wqkv for flash-attn weights).
        self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        # TODO: eps should be 1 / 65530 if using fp16
        if qk_norm:
            self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
            self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis_img=None):
        """Self-attention over `x`; RoPE rotates both Q and K when
        `freqs_cis_img` is given."""
        batch, seq, _ = x.shape
        # [B, N, 3*C] -> [3, B, H, N, D]
        qkv = self.Wqkv(x).reshape(batch, seq, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)     # each [b, h, s, d]
        q = self.q_norm(q)
        k = self.k_norm(k)

        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
            assert qq.shape == q.shape and kk.shape == k.shape, \
                f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        attn_out = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
        return (self.proj_drop(self.out_proj(attn_out)),)
|
ComfyUI/comfy/ldm/hydit/controlnet.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from comfy.ldm.modules.diffusionmodules.mmdit import (
|
| 7 |
+
TimestepEmbedder,
|
| 8 |
+
PatchEmbed,
|
| 9 |
+
)
|
| 10 |
+
from .poolers import AttentionPool
|
| 11 |
+
|
| 12 |
+
import comfy.latent_formats
|
| 13 |
+
from .models import HunYuanDiTBlock, calc_rope
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class HunYuanControlNet(nn.Module):
    """
    ControlNet for HunYuanDiT: a 19-block copy of the DiT backbone that
    consumes a spatial conditioning `hint` and returns one residual tensor
    per block for the main model to add during its own forward pass.

    Parameters
    ----------
    input_size: tuple
        The size of the input image.
    patch_size: int
        The size of the patch.
    in_channels: int
        The number of input channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    depth: int
        The number of transformer blocks (stored, but the block list is
        hard-coded to 19 entries below).
    num_heads: int
        The number of attention heads.
    mlp_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    log_fn: callable
        The logging function.
    """

    def __init__(
        self,
        input_size: tuple = 128,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1408,
        depth: int = 40,
        num_heads: int = 16,
        mlp_ratio: float = 4.3637,
        text_states_dim=1024,
        text_states_dim_t5=2048,
        text_len=77,
        text_len_t5=256,
        qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
        size_cond=False,
        use_style_cond=False,
        learn_sigma=True,
        norm="layer",
        log_fn: callable = print,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.log_fn = log_fn
        self.depth = depth
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # Double the output channels when predicting sigma alongside epsilon.
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.text_states_dim = text_states_dim
        self.text_states_dim_t5 = text_states_dim_t5
        self.text_len = text_len
        self.text_len_t5 = text_len_t5
        self.size_cond = size_cond
        self.use_style_cond = use_style_cond
        self.norm = norm
        self.dtype = dtype
        self.latent_format = comfy.latent_formats.SDXL

        # Project T5 embeddings (2048-d) down to the CLIP text width (1024-d)
        # so both text streams can be concatenated.
        # NOTE(review): plain nn.Linear here (not operations.Linear), unlike
        # the other projections below — confirm this asymmetry is intended.
        self.mlp_t5 = nn.Sequential(
            nn.Linear(
                self.text_states_dim_t5,
                self.text_states_dim_t5 * 4,
                bias=True,
                dtype=dtype,
                device=device,
            ),
            nn.SiLU(),
            nn.Linear(
                self.text_states_dim_t5 * 4,
                self.text_states_dim,
                bias=True,
                dtype=dtype,
                device=device,
            ),
        )
        # Learnable padding embeddings substituted for masked-out text tokens;
        # rows [:text_len] pad the CLIP stream, rows [text_len:] the T5 stream.
        self.text_embedding_padding = nn.Parameter(
            torch.randn(
                self.text_len + self.text_len_t5,
                self.text_states_dim,
                dtype=dtype,
                device=device,
            )
        )

        # Attention pooling over the T5 sequence to one global vector.
        pooler_out_dim = 1024
        self.pooler = AttentionPool(
            self.text_len_t5,
            self.text_states_dim_t5,
            num_heads=8,
            output_dim=pooler_out_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Dimension of the extra conditioning vector fed to extra_embedder.
        self.extra_in_dim = pooler_out_dim

        if self.size_cond:
            # Image size and crop size conditions (6 values x 256-d each).
            self.extra_in_dim += 6 * 256

        if self.use_style_cond:
            # Default learned style embedder kept for future extension.
            # NOTE(review): forward() uses self.style_embedder whenever
            # `style` is passed — it only exists when use_style_cond is True.
            self.style_embedder = nn.Embedding(
                1, hidden_size, dtype=dtype, device=device
            )
            self.extra_in_dim += hidden_size

        # Latent patchifier and timestep embedder.
        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            hidden_size,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations
        )
        # MLP that maps the concatenated extra vector to hidden_size,
        # added onto the timestep embedding.
        self.extra_embedder = nn.Sequential(
            operations.Linear(
                self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
            ),
            nn.SiLU(),
            operations.Linear(
                hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
            ),
        )

        # HunYuanDiT blocks — fixed at 19, independent of `depth`.
        self.blocks = nn.ModuleList(
            [
                HunYuanDiTBlock(
                    hidden_size=hidden_size,
                    c_emb_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    text_states_dim=self.text_states_dim,
                    qk_norm=qk_norm,
                    norm_type=self.norm,
                    skip=False,
                    attn_precision=attn_precision,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(19)
            ]
        )

        # Input projection applied to the hint before the first block
        # ("zero linear" upstream; weights come from the checkpoint here).
        self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)

        # One output projection per block ("zero linear" upstream).
        self.after_proj_list = nn.ModuleList(
            [

                operations.Linear(
                    self.hidden_size, self.hidden_size, dtype=dtype, device=device
                )
                for _ in range(len(self.blocks))
            ]
        )

    def forward(
        self,
        x,
        hint,
        timesteps,
        context,  # encoder_hidden_states in the upstream signature
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        return_dict=False,
        **kwarg,
    ):
        """
        Compute the per-block controlnet residuals.

        Parameters
        ----------
        x: torch.Tensor
            Noisy latent, (B, D, H, W).
        hint: torch.Tensor
            Spatial conditioning image in latent space; broadcast over the
            batch when it has batch size 1.
        timesteps: torch.Tensor
            (B,) diffusion timesteps.
        context: torch.Tensor
            CLIP text embedding, (B, L_clip, D).
        text_embedding_mask: torch.Tensor
            CLIP text embedding mask, (B, L_clip).
        encoder_hidden_states_t5: torch.Tensor
            T5 text embedding, (B, L_t5, D_t5).
        text_embedding_mask_t5: torch.Tensor
            T5 text embedding mask, (B, L_t5).
        image_meta_size: torch.Tensor
            (B, 6); currently unused (handling is commented out below).
        style: torch.Tensor
            (B,); style ids, only valid when the model was built with
            use_style_cond=True.
        return_dict: bool
            Unused; the method always returns a dict.

        Returns
        -------
        dict with key "output": list of per-block residual tensors.
        """
        condition = hint
        # Broadcast a single hint across the whole batch.
        if condition.shape[0] == 1:
            condition = torch.repeat_interleave(condition, x.shape[0], dim=0)

        text_states = context  # 2,77,1024
        text_states_t5 = encoder_hidden_states_t5  # 2,256,2048
        text_states_mask = text_embedding_mask.bool()  # 2,77
        text_states_t5_mask = text_embedding_mask_t5.bool()  # 2,256
        b_t5, l_t5, c_t5 = text_states_t5.shape
        # Project T5 tokens to the CLIP width (flattened through the MLP).
        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

        # Replace masked-out tokens with learned padding embeddings.
        # NOTE(review): text_states aliases `context`, so this assignment
        # mutates the caller's tensor in place — confirm that is intended.
        text_states[:, -self.text_len :] = torch.where(
            text_states_mask[:, -self.text_len :].unsqueeze(2),
            text_states[:, -self.text_len :],
            padding[: self.text_len],
        )
        text_states_t5[:, -self.text_len_t5 :] = torch.where(
            text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
            text_states_t5[:, -self.text_len_t5 :],
            padding[self.text_len :],
        )

        # Concatenate both text streams along the sequence dimension.
        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024

        # _, _, oh, ow = x.shape
        # th, tw = oh // self.patch_size, ow // self.patch_size

        # Get image RoPE embedding according to `reso`lution.
        freqs_cis_img = calc_rope(
            x, self.patch_size, self.hidden_size // self.num_heads
        )  # (cos_cis_img, sin_cis_img)

        # ========================= Build time and image embedding =========================
        t = self.t_embedder(timesteps, dtype=self.dtype)
        x = self.x_embedder(x)

        # ========================= Concatenate all extra vectors =========================
        # Build text tokens with pooling over the (unprojected) T5 states.
        extra_vec = self.pooler(encoder_hidden_states_t5)

        # Build image meta size tokens if applicable (disabled upstream).
        # if image_meta_size is not None:
        #     image_meta_size = timestep_embedding(image_meta_size.view(-1), 256)   # [B * 6, 256]
        #     if image_meta_size.dtype != self.dtype:
        #         image_meta_size = image_meta_size.half()
        #     image_meta_size = image_meta_size.view(-1, 6 * 256)
        #     extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]

        # Build style tokens (requires use_style_cond at construction time).
        if style is not None:
            style_embedding = self.style_embedder(style)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        # Combined conditioning vector: timestep + projected extras.
        c = t + self.extra_embedder(extra_vec)  # [B, D]

        # ========================= Deal with Condition =========================
        # Patchify the hint with the same embedder as the latent.
        condition = self.x_embedder(condition)

        # ========================= Forward pass through HunYuanDiT blocks =========================
        controls = []
        x = x + self.before_proj(condition)  # add condition
        for layer, block in enumerate(self.blocks):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(self.after_proj_list[layer](x))  # zero linear for output

        return {"output": controls}
|
ComfyUI/comfy/ldm/hydit/models.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
import comfy.ops
|
| 6 |
+
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
|
| 7 |
+
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
| 8 |
+
from torch.utils import checkpoint
|
| 9 |
+
|
| 10 |
+
from .attn_layers import Attention, CrossAttention
|
| 11 |
+
from .poolers import AttentionPool
|
| 12 |
+
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
| 13 |
+
|
| 14 |
+
def calc_rope(x, patch_size, head_size):
    """Compute 2D rotary position embeddings (cos, sin) for latent `x`,
    sized to its patch grid and moved to x's device/dtype."""
    half = patch_size // 2
    grid_h = (x.shape[2] + half) // patch_size
    grid_w = (x.shape[3] + half) // patch_size
    # Reference grid for a 512px image through the 8x VAE downscale.
    base_size = 512 // 8 // patch_size
    start, stop = get_fill_resize_and_crop((grid_h, grid_w), base_size)
    # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
    cos_cis, sin_cis = get_2d_rotary_pos_embed(head_size, start, stop, (grid_h, grid_w))
    return (cos_cis.to(x), sin_cis.to(x))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def modulate(x, shift, scale):
    """adaLN shift-and-scale: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are per-sample vectors; a token axis is inserted
    so they broadcast over the sequence dimension of ``x``.
    """
    scaled = x * (scale.unsqueeze(1) + 1)
    return scaled + shift.unsqueeze(1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class HunYuanDiTBlock(nn.Module):
    """
    A HunYuanDiT block with `add` conditioning.

    Layout per block: optional long-skip merge -> self-attention (with an
    additive timestep/condition shift) -> cross-attention against text
    states -> MLP.  All three stages are residual.
    """
    def __init__(self,
                 hidden_size,
                 c_emb_size,
                 num_heads,
                 mlp_ratio=4.0,
                 text_states_dim=1024,
                 qk_norm=False,
                 norm_type="layer",
                 skip=False,
                 attn_precision=None,
                 dtype=None,
                 device=None,
                 operations=None,
                 ):
        super().__init__()
        use_ele_affine = True

        # Pick the normalization class from the (comfy.ops) operations factory.
        if norm_type == "layer":
            norm_layer = operations.LayerNorm
        elif norm_type == "rms":
            norm_layer = operations.RMSNorm
        else:
            raise ValueError(f"Unknown norm_type: {norm_type}")

        # ========================= Self-Attention =========================
        self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)

        # ========================= FFN =========================
        self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, dtype=dtype, device=device, operations=operations)

        # ========================= Add =========================
        # Simply use add like SDXL: the conditioning vector contributes only a
        # shift (no scale/gate) before self-attention.
        self.default_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_emb_size, hidden_size, bias=True, dtype=dtype, device=device)
        )

        # ========================= Cross-Attention =========================
        self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
                                    qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)

        # ========================= Skip Connection =========================
        # U-Net style long skip: second-half blocks concatenate the matching
        # first-half activation and project back down to hidden_size.
        if skip:
            self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
            self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, dtype=dtype, device=device)
        else:
            self.skip_linear = None

        self.gradient_checkpointing = False

    def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        """Core block computation.

        x: image tokens; c: pooled conditioning vector; text_states: text
        tokens for cross-attention; freq_cis_img: RoPE (cos, sin) tables;
        skip: matching first-half activation (only when skip_linear exists).
        """
        # Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([x, skip], dim=-1)
            if cat.dtype != x.dtype:
                cat = cat.to(x.dtype)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        # Self-Attention: conditioning enters as a per-sample additive shift.
        shift_msa = self.default_modulation(c).unsqueeze(dim=1)
        attn_inputs = (
            self.norm1(x) + shift_msa, freq_cis_img,
        )
        # attn1 returns a tuple; [0] is the attention output.
        x = x + self.attn1(*attn_inputs)[0]

        # Cross-Attention
        cross_inputs = (
            self.norm3(x), text_states, freq_cis_img
        )
        x = x + self.attn2(*cross_inputs)[0]

        # FFN Layer
        mlp_inputs = self.norm2(x)
        x = x + self.mlp(mlp_inputs)

        return x

    def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        # Trade compute for memory during training when checkpointing is on.
        if self.gradient_checkpointing and self.training:
            return checkpoint.checkpoint(self._forward, x, c, text_states, freq_cis_img, skip)
        return self._forward(x, c, text_states, freq_cis_img, skip)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class FinalLayer(nn.Module):
    """
    The final layer of HunYuanDiT.

    Applies adaLN modulation (shift/scale derived from the conditioning
    vector) to the normalized tokens, then projects each token to
    ``patch_size**2 * out_channels`` values for unpatchifying.
    """
    def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        out_dim = patch_size * patch_size * out_channels
        self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(final_hidden_size, out_dim, bias=True, dtype=dtype, device=device)
        # SiLU + Linear produces the concatenated (shift, scale) pair.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
        )

    def forward(self, x, c):
        # Split the conditioning projection into its shift and scale halves.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class HunYuanDiT(nn.Module):
    """
    HunYuanDiT: Diffusion model with a Transformer backbone.

    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.

    Parameters
    ----------
    args: argparse.Namespace
        The arguments parsed by argparse.
    input_size: tuple
        The size of the input image.
    patch_size: int
        The size of the patch.
    in_channels: int
        The number of input channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    depth: int
        The number of transformer blocks.
    num_heads: int
        The number of attention heads.
    mlp_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    log_fn: callable
        The logging function.
    """
    #@register_to_config
    def __init__(self,
                 input_size: tuple = 32,
                 patch_size: int = 2,
                 in_channels: int = 4,
                 hidden_size: int = 1152,
                 depth: int = 28,
                 num_heads: int = 16,
                 mlp_ratio: float = 4.0,
                 text_states_dim = 1024,
                 text_states_dim_t5 = 2048,
                 text_len = 77,
                 text_len_t5 = 256,
                 qk_norm = True,# See http://arxiv.org/abs/2302.05442 for details.
                 size_cond = False,
                 use_style_cond = False,
                 learn_sigma = True,
                 norm = "layer",
                 log_fn: callable = print,
                 attn_precision=None,
                 dtype=None,
                 device=None,
                 operations=None,
                 **kwargs,
                 ):
        super().__init__()
        self.log_fn = log_fn
        self.depth = depth
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # With learned sigma the model predicts mean and variance per channel.
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.text_states_dim = text_states_dim
        self.text_states_dim_t5 = text_states_dim_t5
        self.text_len = text_len
        self.text_len_t5 = text_len_t5
        self.size_cond = size_cond
        self.use_style_cond = use_style_cond
        self.norm = norm
        self.dtype = dtype
        #import pdb
        #pdb.set_trace()

        # Projects T5 token embeddings (2048-d) down to the CLIP width so the
        # two text streams can be concatenated.
        self.mlp_t5 = nn.Sequential(
            operations.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True, dtype=dtype, device=device),
        )
        # learnable replace: padding vectors substituted where the text mask
        # marks tokens as inactive (first text_len rows for CLIP, rest for T5).
        self.text_embedding_padding = nn.Parameter(
            torch.empty(self.text_len + self.text_len_t5, self.text_states_dim, dtype=dtype, device=device))

        # Attention pooling: reduces the T5 sequence to one 1024-d vector.
        pooler_out_dim = 1024
        self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim, dtype=dtype, device=device, operations=operations)

        # Dimension of the extra input vectors
        self.extra_in_dim = pooler_out_dim

        if self.size_cond:
            # Image size and crop size conditions: 6 sinusoidal embeddings of 256-d each.
            self.extra_in_dim += 6 * 256

        if self.use_style_cond:
            # Here we use a default learned embedder layer for future extension.
            self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
            self.extra_in_dim += hidden_size

        # Text embedding for `add`
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, dtype=dtype, device=device, operations=operations)
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device, operations=operations)
        # Fuses all extra conditioning into one hidden_size vector added to t.
        self.extra_embedder = nn.Sequential(
            operations.Linear(self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
        )

        # HUnYuanDiT Blocks: blocks past the midpoint receive long skips.
        self.blocks = nn.ModuleList([
            HunYuanDiTBlock(hidden_size=hidden_size,
                            c_emb_size=hidden_size,
                            num_heads=num_heads,
                            mlp_ratio=mlp_ratio,
                            text_states_dim=self.text_states_dim,
                            qk_norm=qk_norm,
                            norm_type=self.norm,
                            skip=layer > depth // 2,
                            attn_precision=attn_precision,
                            dtype=dtype,
                            device=device,
                            operations=operations,
                            )
            for layer in range(depth)
        ])

        self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
        self.unpatchify_channels = self.out_channels



    def forward(self,
                x,
                t,
                context,#encoder_hidden_states=None,
                text_embedding_mask=None,
                encoder_hidden_states_t5=None,
                text_embedding_mask_t5=None,
                image_meta_size=None,
                style=None,
                return_dict=False,
                control=None,
                transformer_options={},
                ):
        """
        Forward pass of the encoder.

        Parameters
        ----------
        x: torch.Tensor
            (B, D, H, W)
        t: torch.Tensor
            (B)
        encoder_hidden_states: torch.Tensor
            CLIP text embedding, (B, L_clip, D)
        text_embedding_mask: torch.Tensor
            CLIP text embedding mask, (B, L_clip)
        encoder_hidden_states_t5: torch.Tensor
            T5 text embedding, (B, L_t5, D)
        text_embedding_mask_t5: torch.Tensor
            T5 text embedding mask, (B, L_t5)
        image_meta_size: torch.Tensor
            (B, 6)
        style: torch.Tensor
            (B)
        cos_cis_img: torch.Tensor
        sin_cis_img: torch.Tensor
        return_dict: bool
            Whether to return a dictionary.
        """
        patches_replace = transformer_options.get("patches_replace", {})
        encoder_hidden_states = context
        text_states = encoder_hidden_states                     # 2,77,1024
        text_states_t5 = encoder_hidden_states_t5               # 2,256,2048
        text_states_mask = text_embedding_mask.bool()           # 2,77
        text_states_t5_mask = text_embedding_mask_t5.bool()     # 2,256
        b_t5, l_t5, c_t5 = text_states_t5.shape
        # Project T5 features to CLIP width (flatten batch*len for the MLP).
        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

        # Replace masked-out (inactive) tokens with the learned padding rows.
        # NOTE(review): indexing with -self.text_len assumes the prompt tokens
        # sit at the tail of the sequence — verify against the conditioning code.
        text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
        text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])

        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024
        # clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)

        _, _, oh, ow = x.shape
        # Token grid size after patching (rounds up for odd latent sizes).
        th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size


        # Get image RoPE embedding according to `reso`lution.
        freqs_cis_img = calc_rope(x, self.patch_size, self.hidden_size // self.num_heads) #(cos_cis_img, sin_cis_img)

        # ========================= Build time and image embedding =========================
        t = self.t_embedder(t, dtype=x.dtype)
        x = self.x_embedder(x)

        # ========================= Concatenate all extra vectors =========================
        # Build text tokens with pooling
        extra_vec = self.pooler(encoder_hidden_states_t5)

        # Build image meta size tokens if applicable
        if self.size_cond:
            image_meta_size = timestep_embedding(image_meta_size.view(-1), 256).to(x.dtype)   # [B * 6, 256]
            image_meta_size = image_meta_size.view(-1, 6 * 256)
            extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]

        # Build style tokens
        if self.use_style_cond:
            if style is None:
                # Default to style index 0 when the caller provides none.
                style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
            style_embedding = self.style_embedder(style, out_dtype=x.dtype)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        # Concatenate all extra vectors
        c = t + self.extra_embedder(extra_vec)  # [B, D]

        blocks_replace = patches_replace.get("dit", {})

        controls = None
        if control:
            controls = control.get("output", None)
        # ========================= Forward pass through HunYuanDiT blocks =========================
        # First-half activations are stacked and popped by second-half blocks,
        # optionally merged with ControlNet residuals.
        skips = []
        for layer, block in enumerate(self.blocks):
            if layer > self.depth // 2:
                if controls is not None:
                    skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                else:
                    skip = skips.pop()
            else:
                skip = None

            # Allow external patches (e.g. samplers) to wrap individual blocks.
            if ("double_block", layer) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
                    return out

                out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, c, text_states, freqs_cis_img, skip)       # (N, L, D)


            if layer < (self.depth // 2 - 1):
                skips.append(x)
        if controls is not None and len(controls) != 0:
            raise ValueError("The number of controls is not equal to the number of skip connections.")

        # ========================= Final layer =========================
        x = self.final_layer(x, c)                              # (N, L, patch_size ** 2 * out_channels)
        x = self.unpatchify(x, th, tw)                          # (N, out_channels, H, W)

        if return_dict:
            return {'x': x}
        if self.learn_sigma:
            # Drop the predicted-variance half and crop padding from rounding.
            return x[:,:self.out_channels // 2,:oh,:ow]
        return x[:,:,:oh,:ow]

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.unpatchify_channels
        p = self.x_embedder.patch_size[0]
        # h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        # Fold the per-token (p, p, c) patches back into a full image grid.
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
|
ComfyUI/comfy/ldm/hydit/poolers.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from comfy.ldm.modules.attention import optimized_attention
|
| 4 |
+
import comfy.ops
|
| 5 |
+
|
| 6 |
+
class AttentionPool(nn.Module):
    """Single-query attention pooling over a token sequence (CLIP-style).

    The sequence mean is prepended as a query token, a learned positional
    embedding is added, and one multi-head attention step with that single
    query reduces the whole sequence to one pooled vector.
    """
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
        super().__init__()
        # spacial_dim + 1: one extra slot for the mean token prepended in forward().
        self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
        self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        # Output projection; falls back to embed_dim when output_dim is None.
        self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x):
        # Truncate the sequence to the positional-embedding capacity.
        x = x[:,:self.positional_embedding.shape[0] - 1]
        x = x.permute(1, 0, 2)  # NLC -> LNC
        # Prepend the mean token as the pooling query.
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x)  # (L+1)NC

        # Single query (the mean token); keys/values over the full sequence.
        q = self.q_proj(x[:1])
        k = self.k_proj(x)
        v = self.v_proj(x)

        batch_size = q.shape[1]
        head_dim = self.embed_dim // self.num_heads
        # Reshape (L, N, C) -> (N, heads, L, head_dim) for the attention kernel.
        q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)

        attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)

        attn_output = self.c_proj(attn_output)
        # Drop the singleton query dimension -> (N, output_dim).
        return attn_output.squeeze(0)
|
ComfyUI/comfy/ldm/hydit/posemb_layers.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _to_tuple(x):
|
| 7 |
+
if isinstance(x, int):
|
| 8 |
+
return x, x
|
| 9 |
+
else:
|
| 10 |
+
return x
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_fill_resize_and_crop(src, tgt):
    """Aspect-preserving fit of ``src`` inside ``tgt``, centered.

    Returns ``(top, left), (bottom, right)`` — the box, in the target grid,
    that the resized source occupies.  Ints are promoted to square sizes.
    """
    th, tw = (tgt, tgt) if isinstance(tgt, int) else tgt
    h, w = (src, src) if isinstance(src, int) else src

    if h / w > th / tw:
        # Source is relatively taller: match target height, scale width down.
        rh = th
        rw = int(round(th / h * w))
    else:
        # Source is relatively wider (or equal): match target width.
        rw = tw
        rh = int(round(tw / w * h))

    # Center the resized region inside the target box.
    top = int(round((th - rh) / 2.0))
    left = int(round((tw - rw) / 2.0))
    return (top, left), (top + rh, left + rw)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_meshgrid(start, *args):
    """Build a ``[2, H, W]`` float32 coordinate grid.

    Call forms: ``(size)`` -> grid over [0, size); ``(start, stop)`` -> one
    sample per unit step; ``(start, stop, num)`` -> ``num`` evenly spaced
    samples (endpoint excluded).  Channel 0 carries the w-axis coordinates,
    channel 1 the h-axis (numpy 'xy' meshgrid convention).
    """
    pair = lambda v: (v, v) if isinstance(v, int) else v
    if len(args) == 0:
        # start is the grid size; sample [0, size) with unit step.
        num = pair(start)
        lo, hi = (0, 0), num
    elif len(args) == 1:
        # (start, stop) with implicit step 1.
        lo, hi = pair(start), pair(args[0])
        num = (hi[0] - lo[0], hi[1] - lo[1])
    elif len(args) == 2:
        # (start, stop, num): explicit sample count.
        lo, hi, num = pair(start), pair(args[0]), pair(args[1])
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    ys = np.linspace(lo[0], hi[0], num[0], endpoint=False, dtype=np.float32)
    xs = np.linspace(lo[1], hi[1], num[1], endpoint=False, dtype=np.float32)
    # meshgrid with w first, then stack into one [2, H, W] array.
    return np.stack(np.meshgrid(xs, ys), axis=0)
|
| 58 |
+
|
| 59 |
+
#################################################################################
|
| 60 |
+
# Sine/Cosine Positional Embedding Functions #
|
| 61 |
+
#################################################################################
|
| 62 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 63 |
+
|
| 64 |
+
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid = get_meshgrid(start, *args)  # [2, H, W]
    grid = grid.reshape([2, 1, *grid.shape[1:]])
    emb = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        # Reserve zero rows up front for the class/extra tokens.
        emb = np.concatenate([np.zeros([extra_tokens, embed_dim]), emb], axis=0)
    return emb
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2D grid by splitting channels evenly between the two axes."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # First half of the channels encodes grid[0], second half grid[1].
    parts = [
        get_1d_sincos_pos_embed_from_grid(half, grid[0]),  # (H*W, D/2)
        get_1d_sincos_pos_embed_from_grid(half, grid[1]),  # (H*W, D/2)
    ]
    return np.concatenate(parts, axis=1)  # (H*W, D)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (W,H)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder: 10000^(-k / (D/2)) for k = 0..D/2-1.
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / (embed_dim / 2.))
    angles = np.einsum('m,d->md', pos.reshape(-1), freqs)  # (M, D/2)
    # Sines first, cosines second.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
#################################################################################
|
| 116 |
+
# Rotary Positional Embedding Functions #
|
| 117 |
+
#################################################################################
|
| 118 |
+
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
|
| 119 |
+
|
| 120 |
+
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
    """
    This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.

    Parameters
    ----------
    embed_dim: int
        embedding dimension size
    start: int or tuple of int
        If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
        If len(args) == 2, start is start, args[0] is stop, args[1] is num.
    use_real: bool
        If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns
    -------
    pos_embed: torch.Tensor
        [HW, D/2]
    """
    # Sampling grid at (optionally rescaled) target resolution.
    grid = get_meshgrid(start, *args)  # [2, H, W]
    grid = grid.reshape([2, 1, *grid.shape[1:]])
    return get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    """Combine per-axis 1D RoPE tables into a 2D table (half channels each)."""
    assert embed_dim % 4 == 0
    half = embed_dim // 2
    emb_h = get_1d_rotary_pos_embed(half, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
    emb_w = get_1d_rotary_pos_embed(half, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)

    if use_real:
        # Separate cos/sin tables, each (H*W, D/2).
        return (torch.cat([emb_h[0], emb_w[0]], dim=1),
                torch.cat([emb_h[1], emb_w[1]], dim=1))
    # Complex form: one (H*W, D/2) tensor.
    return torch.cat([emb_h, emb_w], dim=1)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
                                   Otherwise, return complex numbers (complex64, [S, D/2]).
    """
    if isinstance(pos, int):
        # A scalar means "positions 0..pos-1".
        pos = np.arange(pos)
    # Frequency ladder theta^(-2k/dim) for the first dim//2 even indices.
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta ** exponents)  # [D/2]
    angles = torch.outer(torch.from_numpy(pos).to(freqs.device), freqs).float()  # [S, D/2]
    if use_real:
        # Duplicate each frequency so the tables match the full head dim [S, D].
        return (angles.cos().repeat_interleave(2, dim=1),
                angles.sin().repeat_interleave(2, dim=1))
    # Unit-magnitude complex exponentials e^{i*angle}.
    return torch.polar(torch.ones_like(angles), angles)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def calc_sizes(rope_img, patch_size, th, tw):
    """Translate a ``rope_img`` spec into positional args for get_2d_rotary_pos_embed."""
    if rope_img == 'extend':
        # Expansion mode: the grid matches the token resolution directly.
        return [(th, tw)]
    if rope_img.startswith('base'):
        # 'base<pixels>': interpolate the grid relative to a fixed base size.
        base_size = int(rope_img[4:]) // 8 // patch_size
        start, stop = get_fill_resize_and_crop((th, tw), base_size)
        return [start, stop, (th, tw)]
    raise ValueError(f"Unknown rope_img: {rope_img}")
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def init_image_posemb(rope_img,
                      resolutions,
                      patch_size,
                      hidden_size,
                      num_heads,
                      log_fn,
                      rope_real=True,
                      ):
    """Precompute an image-RoPE table for every supported resolution.

    Returns a dict keyed by ``str(reso)``; values are (cos, sin) pairs when
    ``rope_real`` is True, otherwise complex tensors.
    """
    tables = {}
    for reso in resolutions:
        # Latent tokens per axis: pixels / 8 (VAE) / patch_size.
        th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
        sub_args = calc_sizes(rope_img, patch_size, th, tw)
        emb = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
        tables[str(reso)] = emb
        log_fn(f"    Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
               f"{emb[0].shape if rope_real else emb.shape}")
    return tables
|
ComfyUI/comfy/ldm/lightricks/model.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import comfy.ldm.modules.attention
|
| 4 |
+
import comfy.ldm.common_dit
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
import math
|
| 7 |
+
from typing import Dict, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Create DDPM-style sinusoidal timestep embeddings.

    Args:
        timesteps: 1-D tensor of N (possibly fractional) timestep indices.
        embedding_dim: output feature dimension.
        flip_sin_to_cos: if True, emit `cos, sin` order instead of `sin, cos`.
        downscale_freq_shift: shifts the denominator of the frequency exponent.
        scale: multiplier applied to the angles before sin/cos.
        max_period: controls the lowest frequency of the embedding.

    Returns:
        A [N, embedding_dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Geometric frequency ladder: max_period ** (-k / (half_dim - shift)).
    steps = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp((-math.log(max_period) * steps) / (half_dim - downscale_freq_shift))

    angles = timesteps[:, None].float() * freqs[None, :]
    angles = scale * angles

    # sin block first, cos block second ...
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    # ... unless the caller asked for the flipped ordering.
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # Odd target dimension: pad a single zero column on the right.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class TimestepEmbedding(nn.Module):
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
in_channels: int,
|
| 70 |
+
time_embed_dim: int,
|
| 71 |
+
act_fn: str = "silu",
|
| 72 |
+
out_dim: int = None,
|
| 73 |
+
post_act_fn: Optional[str] = None,
|
| 74 |
+
cond_proj_dim=None,
|
| 75 |
+
sample_proj_bias=True,
|
| 76 |
+
dtype=None, device=None, operations=None,
|
| 77 |
+
):
|
| 78 |
+
super().__init__()
|
| 79 |
+
|
| 80 |
+
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
|
| 81 |
+
|
| 82 |
+
if cond_proj_dim is not None:
|
| 83 |
+
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
|
| 84 |
+
else:
|
| 85 |
+
self.cond_proj = None
|
| 86 |
+
|
| 87 |
+
self.act = nn.SiLU()
|
| 88 |
+
|
| 89 |
+
if out_dim is not None:
|
| 90 |
+
time_embed_dim_out = out_dim
|
| 91 |
+
else:
|
| 92 |
+
time_embed_dim_out = time_embed_dim
|
| 93 |
+
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
|
| 94 |
+
|
| 95 |
+
if post_act_fn is None:
|
| 96 |
+
self.post_act = None
|
| 97 |
+
# else:
|
| 98 |
+
# self.post_act = get_activation(post_act_fn)
|
| 99 |
+
|
| 100 |
+
def forward(self, sample, condition=None):
|
| 101 |
+
if condition is not None:
|
| 102 |
+
sample = sample + self.cond_proj(condition)
|
| 103 |
+
sample = self.linear_1(sample)
|
| 104 |
+
|
| 105 |
+
if self.act is not None:
|
| 106 |
+
sample = self.act(sample)
|
| 107 |
+
|
| 108 |
+
sample = self.linear_2(sample)
|
| 109 |
+
|
| 110 |
+
if self.post_act is not None:
|
| 111 |
+
sample = self.post_act(sample)
|
| 112 |
+
return sample
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Timesteps(nn.Module):
    """Stateless module wrapper around get_timestep_embedding.

    Stores the embedding configuration and applies it in forward().
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        # Delegate to the functional implementation with the stored settings.
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.outdim = size_emb_dim
        # 256-channel sinusoidal features, then an MLP up to embedding_dim.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        # resolution / aspect_ratio / batch_size are accepted for interface
        # compatibility but unused here (no additional conditions).
        projected = self.time_proj(timestep)
        return self.timestep_embedder(projected.to(dtype=hidden_dtype))  # (N, D)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
        )

        self.silu = nn.SiLU()
        # Produces the 6 modulation vectors (shift/scale/gate x 2) consumed
        # by the transformer blocks.
        self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # NOTE(review): the annotation declares five tensors but two are returned.
        # No modulation happening here; falsy kwargs fall back to the defaults.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        modulation = self.linear(self.silu(embedded_timestep))
        return modulation, embedded_timestep
|
| 187 |
+
|
| 188 |
+
class PixArtAlphaTextProjection(nn.Module):
|
| 189 |
+
"""
|
| 190 |
+
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
| 191 |
+
|
| 192 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
|
| 196 |
+
super().__init__()
|
| 197 |
+
if out_features is None:
|
| 198 |
+
out_features = hidden_size
|
| 199 |
+
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
|
| 200 |
+
if act_fn == "gelu_tanh":
|
| 201 |
+
self.act_1 = nn.GELU(approximate="tanh")
|
| 202 |
+
elif act_fn == "silu":
|
| 203 |
+
self.act_1 = nn.SiLU()
|
| 204 |
+
else:
|
| 205 |
+
raise ValueError(f"Unknown activation function: {act_fn}")
|
| 206 |
+
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
|
| 207 |
+
|
| 208 |
+
def forward(self, caption):
|
| 209 |
+
hidden_states = self.linear_1(caption)
|
| 210 |
+
hidden_states = self.act_1(hidden_states)
|
| 211 |
+
hidden_states = self.linear_2(hidden_states)
|
| 212 |
+
return hidden_states
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class GELU_approx(nn.Module):
|
| 216 |
+
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
|
| 219 |
+
|
| 220 |
+
def forward(self, x):
|
| 221 |
+
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class FeedForward(nn.Module):
    """Transformer MLP: GELU_approx expansion -> dropout -> linear back to dim_out."""

    def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
        super().__init__()
        hidden_dim = int(dim * mult)
        # NOTE(review): `glu` is accepted but unused — GELU_approx is always used.
        self.net = nn.Sequential(
            GELU_approx(dim, hidden_dim, dtype=dtype, device=device, operations=operations),
            nn.Dropout(dropout),
            operations.Linear(hidden_dim, dim_out, dtype=dtype, device=device),
        )

    def forward(self, x):
        return self.net(x)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
    """Apply an interleaved rotary position embedding to `input_tensor`.

    Treats the last dimension as interleaved (real, imag) pairs and rotates
    each pair by the angle encoded in `freqs_cis`.

    Args:
        input_tensor: tensor whose last dimension is even.
        freqs_cis: (cos, sin) pair, each broadcastable to input_tensor's shape,
            with values duplicated per pair (repeat_interleave layout).

    Returns:
        The rotated tensor, same shape as `input_tensor`.
    """
    cos_freqs = freqs_cis[0]
    sin_freqs = freqs_cis[1]

    # Native reshape/stack instead of two einops rearrange calls: this runs per
    # attention layer, and rearrange's pattern parsing is avoidable overhead.
    # Pair rotation: (a, b) -> (-b, a), i.e. multiplication by i.
    pairs = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2)
    t1, t2 = pairs.unbind(dim=-1)
    input_tensor_rot = torch.stack((-t2, t1), dim=-1).reshape(input_tensor.shape)

    # Standard RoPE combination: x * cos + rotate(x) * sin.
    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

    return out
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class CrossAttention(nn.Module):
    """Multi-head attention with RMS-normalized q/k and optional rotary embedding.

    Acts as self-attention when `context` is None, cross-attention otherwise.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        # RMS norms applied to the full projected q/k (all heads at once).
        self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
        self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None, pe=None):
        if context is None:
            context = x
        q = self.q_norm(self.to_q(x))
        k = self.k_norm(self.to_k(context))
        v = self.to_v(context)

        if pe is not None:
            # Rotary embedding is applied after normalization, to q and k only.
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe)

        if mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        return self.to_out(out)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class BasicTransformerBlock(nn.Module):
    # LTXV transformer block: AdaLN-modulated self-attention, cross-attention
    # to the caption context, then an AdaLN-modulated feed-forward — each with
    # a residual connection.
    def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()

        self.attn_precision = attn_precision
        # attn1: self-attention (context_dim=None); attn2: cross-attention.
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
        self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)

        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)

        # Learned per-block offsets added to the timestep modulation:
        # 6 rows = shift/scale/gate for the attention path and the MLP path.
        self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
        # `timestep` carries 6*dim modulation values per sequence (or per token);
        # reshape to [..., 6, dim] and split into the six modulation tensors.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

        # Self-attention over the modulated RMS-normed input (in-place residual add).
        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa

        # Cross-attention to the conditioning context — unmodulated, ungated.
        x += self.attn2(x, context=context, mask=attention_mask)

        # Feed-forward on the modulated RMS-normed features, gated residual add.
        y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
        x += self.ff(y) * gate_mlp

        return x
|
| 316 |
+
|
| 317 |
+
def get_fractional_positions(indices_grid, max_pos):
    """Normalize a [batch, 3, n] (t, h, w) coordinate grid by per-axis maxima.

    Returns a [batch, n, 3] tensor of positions in units of `max_pos`.
    """
    per_axis = [indices_grid[:, axis] / max_pos[axis] for axis in range(3)]
    return torch.stack(per_axis, dim=-1)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
    """Build interleaved rotary cos/sin tables for fractional 3D (t, h, w) positions.

    Args:
        indices_grid: [batch, 3, n] pixel-space (t, h, w) coordinates.
        dim: head/model dimension the tables are built for.
        out_dtype: dtype of the returned tables.
        theta: RoPE base frequency.
        max_pos: per-axis maxima used to normalize positions to [0, 1].

    Returns:
        (cos, sin), each of shape [batch, n, dim], cast to out_dtype.
    """
    dtype = torch.float32  # compute in float32 regardless of out_dtype

    # Positions normalized to [0, 1] per axis, shape [batch, n, 3].
    fractional_positions = torch.stack(
        [indices_grid[:, axis] / max_pos[axis] for axis in range(3)],
        dim=-1,
    )
    device = fractional_positions.device

    # dim // 6 frequencies per axis (3 axes, each duplicated by the interleave),
    # geometric from theta**0 = 1 up to theta**1 = theta.
    start = 1
    end = theta
    indices = theta ** (
        torch.linspace(
            math.log(start, theta),
            math.log(end, theta),
            dim // 6,
            device=device,
            dtype=dtype,
        )
    )
    indices = indices.to(dtype=dtype)
    indices = indices * math.pi / 2

    # Map positions from [0, 1] to [-1, 1], take the outer product with the
    # frequencies, then flatten (freq, axis) into a single channel dimension.
    freqs = (
        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
        .transpose(-1, -2)
        .flatten(2)
    )

    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if dim % 6 != 0:
        # Leading channels get the identity rotation (cos=1, sin=0).
        pad = dim % 6
        cos_padding = torch.ones_like(cos_freq[:, :, :pad])
        sin_padding = torch.zeros_like(cos_freq[:, :, :pad])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class LTXVModel(torch.nn.Module):
    # LTX-Video (Lightricks) diffusion transformer. Patchifies video latents
    # into token sequences, applies AdaLN-single timestep modulation and
    # fractional 3D rotary position embeddings, and unpatchifies the result
    # back to latent shape.
    def __init__(self,
                 in_channels=128,
                 cross_attention_dim=2048,
                 attention_head_dim=64,
                 num_attention_heads=32,

                 caption_channels=4096,
                 num_layers=28,


                 positional_embedding_theta=10000.0,
                 positional_embedding_max_pos=[20, 2048, 2048],
                 causal_temporal_positioning=False,
                 vae_scale_factors=(8, 32, 32),
                 dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.generator = None
        self.vae_scale_factors = vae_scale_factors
        self.dtype = dtype
        # Output channels equal input channels (model predicts in latent space).
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        self.causal_temporal_positioning = causal_temporal_positioning

        # Projects flattened latent patches into the transformer width.
        self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
        )

        # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)

        # Maps caption embeddings (caption_channels wide) to the transformer width.
        self.caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    context_dim=cross_attention_dim,
                    # attn_precision=attn_precision,
                    dtype=dtype, device=device, operations=operations
                )
                for d in range(num_layers)
            ]
        )

        # Final shift/scale modulation (2 rows) + projection back to latent channels.
        self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
        self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)

        # Temporal patch size 1, spatial patch size 1x1 (tokens = latent voxels).
        self.patchifier = SymmetricPatchifier(1)

    def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
        patches_replace = transformer_options.get("patches_replace", {})

        orig_shape = list(x.shape)  # [b, c, f, h, w], needed to unpatchify later

        # 1. Tokenize the latent video and derive per-token pixel coordinates.
        x, latent_coords = self.patchifier.patchify(x)
        pixel_coords = latent_to_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=self.vae_scale_factors,
            causal_fix=self.causal_temporal_positioning,
        )

        if keyframe_idxs is not None:
            # Override the coordinates of the trailing keyframe tokens.
            pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs

        # Temporal coordinate is expressed in seconds (frames / frame_rate).
        fractional_coords = pixel_coords.to(torch.float32)
        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)

        x = self.patchify_proj(x)
        # Sampler provides timesteps in [0, 1]; the model was trained on [0, 1000].
        timestep = timestep * 1000.0

        if attention_mask is not None and not torch.is_floating_point(attention_mask):
            # Convert a 0/1 mask to an additive bias: 0 for keep, -finfo.max for masked.
            attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max

        # Rotary cos/sin tables from the fractional (t, h, w) coordinates.
        pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)

        batch_size = x.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=x.dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(
            batch_size, -1, embedded_timestep.shape[-1]
        )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = x.shape[0]
            context = self.caption_projection(context)
            context = context.view(
                batch_size, -1, x.shape[-1]
            )

        # Allow external patches (e.g. samplers/extensions) to wrap each block.
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(
                    x,
                    context=context,
                    attention_mask=attention_mask,
                    timestep=timestep,
                    pe=pe
                )

        # 3. Output
        scale_shift_values = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        x = self.norm_out(x)
        # Modulation
        x = x * (1 + scale) + shift
        x = self.proj_out(x)

        # Fold the token sequence back into the original [b, c, f, h, w] layout.
        x = self.patchifier.unpatchify(
            latents=x,
            output_height=orig_shape[3],
            output_width=orig_shape[4],
            output_num_frames=orig_shape[2],
            out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
        )

        return x
|
ComfyUI/comfy/ldm/lightricks/symmetric_patchifier.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def latent_to_pixel_coords(
    latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
) -> Tensor:
    """
    Scale latent (t, h, w) token coordinates to pixel space using the VAE's
    per-axis scale factors.

    Args:
        latent_coords (Tensor): [batch_size, 3, num_latents] latent corner
            coordinates of each token.
        scale_factors (Tuple[int, int, int]): VAE latent-space scale factors.
        causal_fix (bool): Whether to account for the different temporal scale
            of the first frame. Default = False for backwards compatibility.
    Returns:
        Tensor: pixel coordinates corresponding to the input latent coordinates.
    """
    factors = torch.tensor(scale_factors, device=latent_coords.device)
    pixel_coords = latent_coords * factors[None, :, None]
    if causal_fix:
        # The causal VAE encodes the first frame with temporal extent 1 rather
        # than scale_factors[0]; shift subsequent frames back and clamp at 0.
        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Patchifier(ABC):
    """Base class for converting latent videos to/from token sequences.

    Subclasses implement patchify/unpatchify; this base provides the patch
    size (temporal patch is always 1) and per-token coordinate generation.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        # (temporal, height, width) patch extents; temporal is fixed to 1.
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(
        self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
    ) -> Tuple[Tensor, Tensor]:
        pass

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        pass

    @property
    def patch_size(self):
        return self._patch_size

    def get_latent_coords(
        self, latent_num_frames, latent_height, latent_width, batch_size, device
    ):
        """
        Return a tensor of shape [batch_size, 3, num_patches] containing the
        top-left corner latent coordinates of each latent patch.
        The tensor is repeated for each batch element.
        """
        axis_ranges = [
            torch.arange(0, size, step, device=device)
            for size, step in zip(
                (latent_num_frames, latent_height, latent_width), self._patch_size
            )
        ]
        grid = torch.stack(torch.meshgrid(*axis_ranges, indexing="ij"), dim=0)
        # [3, f, h, w] -> [batch, 3, f*h*w]; flatten replaces the previous
        # einops rearrange ("b c f h w -> b c (f h w)") — a pure flatten needs
        # no pattern parsing and no einops dependency.
        return grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1).flatten(2)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class SymmetricPatchifier(Patchifier):
    # Patchifier with square spatial patches (p x p) and temporal patch 1.
    def patchify(
        self,
        latents: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        # [b, c, f, h, w] -> tokens [b, f*h*w, c*p1*p2*p3], plus the per-token
        # latent coordinates used for positional embeddings.
        b, _, f, h, w = latents.shape
        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
        latents = rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents, latent_coords

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        # Inverse of patchify. NOTE(review): only the spatial patch factors
        # (p, q) are unfolded here — the temporal patch is assumed to be 1,
        # which holds for this class (see Patchifier.__init__).
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q) ",
            f=output_num_frames,
            h=output_height,
            w=output_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents
|
ComfyUI/comfy/ldm/lightricks/vae/causal_conv3d.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import comfy.ops
|
| 6 |
+
ops = comfy.ops.disable_weight_init
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CausalConv3d(nn.Module):
    """3D convolution with manual temporal padding.

    Temporal padding replicates edge frames: causal mode left-pads with the
    first frame only (output frame t depends on frames <= t); non-causal mode
    pads both ends symmetrically. Spatial padding is delegated to the conv
    via `spatial_padding_mode`.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Cubic kernel; temporal extent drives the manual padding in forward().
        kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        # Dilation applies to the temporal axis only.
        dilation = (dilation, 1, 1)

        # "Same" spatial padding; no temporal padding here (done in forward()).
        padding = (0, kernel_size[1] // 2, kernel_size[2] // 2)

        self.conv = ops.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        pad_total = self.time_kernel_size - 1
        if causal:
            # Replicate the first frame on the left only.
            head = x[:, :, :1, :, :].repeat((1, 1, pad_total, 1, 1))
            x = torch.concatenate((head, x), dim=2)
        else:
            # Replicate first/last frames half on each side.
            head = x[:, :, :1, :, :].repeat((1, 1, pad_total // 2, 1, 1))
            tail = x[:, :, -1:, :, :].repeat((1, 1, pad_total // 2, 1, 1))
            x = torch.concatenate((head, x, tail), dim=2)
        return self.conv(x)

    @property
    def weight(self):
        return self.conv.weight
|
ComfyUI/comfy/ldm/lumina/model.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import comfy.ldm.common_dit
|
| 10 |
+
|
| 11 |
+
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
| 12 |
+
from comfy.ldm.modules.attention import optimized_attention_masked
|
| 13 |
+
from comfy.ldm.flux.layers import EmbedND
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def modulate(x, scale):
    """AdaLN-style modulation: multiply ``x`` (N, L, D) elementwise by
    ``1 + scale``, broadcasting the per-sample ``scale`` (N, D) across the
    sequence dimension."""
    factor = 1 + scale.unsqueeze(1)
    return x * factor
|
| 18 |
+
|
| 19 |
+
#############################################################################
|
| 20 |
+
# Core NextDiT Model #
|
| 21 |
+
#############################################################################
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class JointAttention(nn.Module):
    """Multi-head attention with optional grouped-query attention (GQA),
    optional per-head RMS q/k normalization, and rotary position embeddings."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int],
        qk_norm: bool,
        operation_settings={},
    ):
        """
        Initialize the Attention module.

        Args:
            dim (int): Number of input dimensions.
            n_heads (int): Number of heads.
            n_kv_heads (Optional[int]): Number of kv heads, if using GQA;
                None means one kv head per query head (standard MHA).
            qk_norm (bool): Apply RMSNorm to queries and keys per head.
            operation_settings (dict): Supplies the "operations" layer factory
                plus "device" and "dtype" used to construct submodules.
        """
        super().__init__()
        # With GQA the kv heads may be fewer than the query heads; n_rep is
        # how many query heads share each kv head.
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_local_heads = n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        # Fused q/k/v projection: one matmul produces all three.
        self.qkv = operation_settings.get("operations").Linear(
            dim,
            (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        # Output projection back to the model dimension.
        self.out = operation_settings.get("operations").Linear(
            n_heads * self.head_dim,
            dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        if qk_norm:
            # RMS-normalize queries and keys over the head dimension.
            self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
            self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        else:
            self.q_norm = self.k_norm = nn.Identity()

    @staticmethod
    def apply_rotary_emb(
        x_in: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to an input tensor using the given frequency
        tensor.

        The last dimension of ``x_in`` is viewed as interleaved 2-vectors and
        each pair is combined with the two components stored in the trailing
        dimension of ``freqs_cis``; the result is reshaped back to the input
        shape.

        Args:
            x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
                exponentials; must broadcast against the paired view of x_in.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x_in`` with rotary
            embeddings applied.
        """

        # View the feature dim as (..., pairs, 1, 2), rotate, restore shape.
        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
        return t_out.reshape(*x_in.shape)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run masked multi-head attention over ``x``.

        Args:
            x: Input of shape (batch, seq_len, dim).
            x_mask: Attention mask forwarded to the attention kernel.
            freqs_cis: Rotary-embedding table applied to queries and keys.

        Returns:
            Output projection of the attention result,
            shape (batch, seq_len, dim).
        """
        bsz, seqlen, _ = x.shape

        # Slice the fused projection into query/key/value chunks.
        xq, xk, xv = torch.split(
            self.qkv(x),
            [
                self.n_local_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
            ],
            dim=-1,
        )
        # Reshape to (batch, seq, heads, head_dim).
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        # Rotary position encoding on queries and keys only.
        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)

        # GQA: replicate each kv head n_rep times so kv heads line up with
        # query heads. (n_rep == 1 makes this a no-op reshape round-trip.)
        n_rep = self.n_local_heads // self.n_local_kv_heads
        if n_rep >= 1:
            xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
        # movedim(1, 2) puts tensors in (batch, heads, seq, head_dim) as the
        # kernel expects; skip_reshape signals they are already head-split.
        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)

        return self.out(output)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        operation_settings={},
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple
                of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden
                dimension. Defaults to None.
        """
        super().__init__()
        # Optionally rescale the hidden width, then round it up to the next
        # multiple of `multiple_of`.
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        def build_proj(n_in, n_out):
            # Bias-free Linear created through the configured backend.
            return operation_settings.get("operations").Linear(
                n_in,
                n_out,
                bias=False,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            )

        self.w1 = build_proj(dim, hidden_dim)  # gate up-projection
        self.w2 = build_proj(hidden_dim, dim)  # down-projection
        self.w3 = build_proj(dim, hidden_dim)  # value up-projection

    # @torch.compile
    def _forward_silu_gating(self, x1, x3):
        # SiLU-gated elementwise product of the two up-projections.
        return F.silu(x1) * x3

    def forward(self, x):
        gated = self._forward_silu_gating(self.w1(x), self.w3(x))
        return self.w2(gated)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class JointTransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        operation_settings={},
    ) -> None:
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            dim (int): Embedding dimension of the input features.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or set to None for the same as
                query.
            multiple_of (int): Rounding granularity for the feed-forward
                hidden width.
            ffn_dim_multiplier (float): Optional multiplier for the
                feed-forward hidden width.
            norm_eps (float): Epsilon used by the RMSNorm layers.
            qk_norm (bool): Whether to RMS-normalize queries and keys.
            modulation (bool): When True, the block is conditioned via an
                adaLN projection producing per-branch scales and gates; when
                False it runs unconditioned.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
            operation_settings=operation_settings,
        )
        self.layer_id = layer_id
        # "Sandwich" normalization: *_norm1 is applied to the branch input,
        # *_norm2 to the branch output before the residual addition.
        self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        self.modulation = modulation
        if modulation:
            # Conditioning vector -> (scale_msa, gate_msa, scale_mlp, gate_mlp).
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operation_settings.get("operations").Linear(
                    min(dim, 1024),
                    4 * dim,
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor]=None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            x_mask (torch.Tensor): Attention mask forwarded to the attention
                module.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            adaln_input (Optional[torch.Tensor]): Conditioning vector; must be
                provided iff the block was built with ``modulation=True``.

        Returns:
            torch.Tensor: Output tensor after applying attention and
                feedforward layers.

        """
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

            # Residual + tanh-gated, post-normalized attention branch; the
            # branch input is scaled by (1 + scale_msa) via modulate().
            x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
                self.attention(
                    modulate(self.attention_norm1(x), scale_msa),
                    x_mask,
                    freqs_cis,
                )
            )
            # Same structure for the feed-forward branch.
            x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
                self.feed_forward(
                    modulate(self.ffn_norm1(x), scale_mlp),
                )
            )
        else:
            assert adaln_input is None
            # Unconditioned variant: identical sandwich-norm residual layout,
            # without scales or gates.
            x = x + self.attention_norm2(
                self.attention(
                    self.attention_norm1(x),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + self.ffn_norm2(
                self.feed_forward(
                    self.ffn_norm1(x),
                )
            )
        return x
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class FinalLayer(nn.Module):
    """
    The final layer of NextDiT: an adaLN-scaled, affine-free LayerNorm
    followed by a linear projection to flattened output patches.
    """

    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
        super().__init__()
        operations = operation_settings.get("operations")
        device = operation_settings.get("device")
        dtype = operation_settings.get("dtype")

        # Affine-free norm; its scaling comes from adaLN_modulation instead.
        self.norm_final = operations.LayerNorm(
            hidden_size,
            elementwise_affine=False,
            eps=1e-6,
            device=device,
            dtype=dtype,
        )
        # Maps each token to a patch_size x patch_size patch of out_channels.
        self.linear = operations.Linear(
            hidden_size,
            patch_size * patch_size * out_channels,
            bias=True,
            device=device,
            dtype=dtype,
        )

        # Conditioning vector -> per-channel scale for the final norm.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(
                min(hidden_size, 1024),
                hidden_size,
                bias=True,
                device=device,
                dtype=dtype,
            ),
        )

    def forward(self, x, c):
        scale = self.adaLN_modulation(c)
        normed = modulate(self.norm_final(x), scale)
        return self.linear(normed)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class NextDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Pipeline: caption features and noisy image patches are independently
    refined, concatenated per-sample into one sequence (caption tokens first,
    then image tokens), processed by the main transformer stack with 3-axis
    rotary position embeddings, and finally projected back to image patches.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_refiner_layers: int = 2,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (1, 512, 512),
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        # Bundled backend/device/dtype passed to every submodule.
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size

        # Projects a flattened patch (p*p*C values) to the model dimension.
        self.x_embedder = operation_settings.get("operations").Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        # Small timestep-conditioned stack applied to image tokens only.
        self.noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        # Unconditioned stack applied to caption tokens only.
        self.context_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
        # Caption features: RMSNorm then projection into the model dimension.
        self.cap_embedder = nn.Sequential(
            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
            operation_settings.get("operations").Linear(
                cap_feat_dim,
                dim,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

        # Main transformer stack over the joint caption+image sequence.
        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)

        # The per-head dim is split across the 3 rope axes (sequence, row, col).
        assert (dim // n_heads) == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
        self.dim = dim
        self.n_heads = n_heads

    def unpatchify(
        self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
    ) -> List[torch.Tensor]:
        """
        Recover images from the joint token sequence.

        x: (N, T, patch_size**2 * C) — caption tokens first, image tokens after.
        img_size: per-sample (H, W); cap_size: per-sample caption length used
        as the offset where image tokens start.
        Returns a list of (C, H, W) images, or a stacked (N, C, H, W) tensor
        when return_tensor is True.
        """
        pH = pW = self.patch_size
        imgs = []
        for i in range(x.size(0)):
            H, W = img_size[i]
            # Image tokens occupy [cap_size[i], cap_size[i] + H/p * W/p).
            begin = cap_size[i]
            end = begin + (H // pH) * (W // pW)
            imgs.append(
                x[i][begin:end]
                .view(H // pH, W // pW, pH, pW, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )

        if return_tensor:
            imgs = torch.stack(imgs, dim=0)
        return imgs

    def patchify_and_embed(
        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
        # Builds the joint caption+image token sequence: computes per-sample
        # lengths, 3-axis rope positions, refines captions and image tokens
        # separately, then packs both into one padded batch.
        bsz = len(x)
        pH = pW = self.patch_size
        device = x[0].device
        dtype = x[0].dtype

        # Effective caption length per sample (mask sum), or fixed num_tokens.
        if cap_mask is not None:
            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
        else:
            l_effective_cap_len = [num_tokens] * bsz

        # Convert a 0/1 mask to an additive bias: kept -> 0, masked -> -max.
        if cap_mask is not None and not torch.is_floating_point(cap_mask):
            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max

        img_sizes = [(img.size(1), img.size(2)) for img in x]
        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]

        max_seq_len = max(
            (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
        )
        max_cap_len = max(l_effective_cap_len)
        max_img_len = max(l_effective_img_len)

        # Rope position ids: axis 0 = sequence index, axes 1/2 = patch row/col.
        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // pH, W // pW
            assert H_tokens * W_tokens == img_len

            # Caption tokens count up along axis 0; image tokens all share the
            # sequence position cap_len and are distinguished by row/col axes.
            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
            position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
            position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)

        # build freqs_cis for cap and image individually
        cap_freqs_cis_shape = list(freqs_cis.shape)
        # cap_freqs_cis_shape[1] = max_cap_len
        cap_freqs_cis_shape[1] = cap_feats.shape[1]
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        # Copy each sample's caption/image slices out of the joint rope table.
        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]

        # refine context
        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

        # refine image
        # Patchify each (C, H, W) image into a (num_patches, p*p*C) sequence.
        flat_x = []
        for i in range(bsz):
            img = x[i]
            C, H, W = img.size()
            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
            flat_x.append(img)
        x = flat_x
        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
        for i in range(bsz):
            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
            # Additive mask: padding positions get -max so attention skips them.
            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max

        padded_img_embed = self.x_embedder(padded_img_embed)
        padded_img_mask = padded_img_mask.unsqueeze(1)
        # Timestep-modulated refinement of the image tokens.
        for layer in self.noise_refiner:
            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)

        if cap_mask is not None:
            # Joint-sequence mask: caption part from cap_mask, image part 0
            # (fully visible). NOTE(review): image padding positions are left
            # at 0 here rather than -max — confirm against upstream Lumina.
            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
        else:
            mask = None

        # Pack refined caption and image tokens into one padded sequence,
        # caption first, per sample.
        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]

            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
            padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]

        return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

    # def forward(self, x, t, cap_feats, cap_mask):
    def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
        # Adapt the comfy calling convention to the original model's one:
        # time is reversed (1 - t) and the output is negated on return.
        t = 1.0 - timesteps
        cap_feats = context
        cap_mask = attention_mask
        bs, c, h, w = x.shape
        # Pad H/W so they are divisible by the patch size; the original h, w
        # are kept to crop the output back at the end.
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        """
        Forward pass of NextDiT.
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of text tokens/features
        """

        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        adaln_input = t

        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D) # todo check if able to batchify w.o. redundant compute

        x_is_tensor = isinstance(x, torch.Tensor)
        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
        freqs_cis = freqs_cis.to(x.device)

        for layer in self.layers:
            x = layer(x, mask, freqs_cis, adaln_input)

        x = self.final_layer(x, adaln_input)
        # Drop caption tokens, rebuild images, and crop away patch padding.
        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]

        return -x
|
| 622 |
+
|
ComfyUI/comfy/ldm/models/autoencoder.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
from typing import Any, Dict, Tuple, Union
|
| 6 |
+
|
| 7 |
+
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
| 8 |
+
|
| 9 |
+
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
|
| 10 |
+
from comfy.ldm.modules.ema import LitEma
|
| 11 |
+
import comfy.ops
|
| 12 |
+
|
| 13 |
+
class DiagonalGaussianRegularizer(torch.nn.Module):
    """Latent regularizer that treats its input as the parameters of a
    diagonal Gaussian posterior and emits either a sample from it or its
    mode."""

    def __init__(self, sample: bool = False):
        super().__init__()
        # True -> stochastic sampling in forward(); False -> deterministic mode.
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        # No learnable parameters: yield nothing (empty generator).
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        posterior = DiagonalGaussianDistribution(z)
        latent = posterior.sample() if self.sample else posterior.mode()
        # NOTE(review): the annotation promises a dict in the second slot but
        # None is returned — preserved as-is for caller compatibility.
        return latent, None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class AbstractAutoencoder(torch.nn.Module):
    """
    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
    unCLIP models, etc. Hence, it is fairly general, and specific features
    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
        **kwargs,
    ):
        """
        Args:
            ema_decay: when provided, an exponential-moving-average copy of
                the weights is tracked with this decay rate.
            monitor: optional metric name; set as an attribute only when given.
            input_key: key under which the input tensor lives in a batch.
        """
        super().__init__()

        self.input_key = input_key
        # EMA tracking is active iff a decay value was supplied.
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    def get_input(self, batch) -> Any:
        """Extract the model input from *batch*; must be overridden."""
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap live weights for their EMA counterparts.

        No-op when EMA tracking is disabled. *context* is only used to
        label the log messages.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                logging.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    logging.info(f"{context}: Restored training weights")

    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        """Build an optimizer described by a {"target", "params"} config dict."""
        logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        optimizer_cls = get_obj_from_str(cfg["target"])
        return optimizer_cls(params, lr=lr, **cfg.get("params", dict()))

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        regularizer_config: Dict,
        **kwargs,
    ):
        """Instantiate encoder, decoder and regularizer from config dicts."""
        super().__init__(*args, **kwargs)

        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
        self.regularization = instantiate_from_config(regularizer_config)

    def get_last_layer(self):
        # Delegated to the decoder.
        return self.decoder.get_last_layer()

    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Encode *x*; optionally bypass the regularizer or return its log."""
        latent = self.encoder(x)
        if unregularized:
            # Raw encoder output, paired with an empty log for API symmetry.
            return latent, dict()
        latent, reg_log = self.regularization(latent)
        return (latent, reg_log) if return_reg_log else latent

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        """Decode latent *z* back to data space."""
        return self.decoder(z, **kwargs)

    def forward(
        self, x: torch.Tensor, **additional_decode_kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Full round trip: returns (latent, reconstruction, reg_log)."""
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class AutoencodingEngineLegacy(AutoencodingEngine):
    """Legacy VAE wrapper: wires the diffusionmodules Encoder/Decoder through
    1x1 quant / post-quant convolutions and supports chunked batching."""

    def __init__(self, embed_dim: int, **kwargs):
        # When set, encode()/decode() process at most this many samples per pass.
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        ddconfig = kwargs.pop("ddconfig")
        super().__init__(
            encoder_config={
                "target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
                "params": ddconfig,
            },
            decoder_config={
                "target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
                "params": ddconfig,
            },
            **kwargs,
        )

        # 3D convs for video-style VAEs, 2D otherwise.
        if ddconfig.get("conv3d", False):
            conv_op = comfy.ops.disable_weight_init.Conv3d
        else:
            conv_op = comfy.ops.disable_weight_init.Conv2d

        # double_z doubles the channel count (mean + logvar halves).
        channel_factor = 1 + ddconfig["double_z"]
        self.quant_conv = conv_op(
            channel_factor * ddconfig["z_channels"],
            channel_factor * embed_dim,
            1,
        )

        self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

    def get_autoencoder_params(self) -> list:
        # NOTE(review): relies on get_autoencoder_params() further up the MRO;
        # no such method is visible in this file — confirm a parent provides it.
        return super().get_autoencoder_params()

    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        """Encode *x* (in max_batch_size chunks when configured) and regularize."""
        bs = self.max_batch_size
        if bs is None:
            z = self.quant_conv(self.encoder(x))
        else:
            n_batches = int(math.ceil(x.shape[0] / bs))
            chunks = [
                self.quant_conv(self.encoder(x[i * bs : (i + 1) * bs]))
                for i in range(n_batches)
            ]
            z = torch.cat(chunks, 0)

        z, reg_log = self.regularization(z)
        return (z, reg_log) if return_reg_log else z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        """Decode *z* (in max_batch_size chunks when configured)."""
        bs = self.max_batch_size
        if bs is None:
            return self.decoder(self.post_quant_conv(z), **decoder_kwargs)
        n_batches = int(math.ceil(z.shape[0] / bs))
        chunks = [
            self.decoder(self.post_quant_conv(z[i * bs : (i + 1) * bs]), **decoder_kwargs)
            for i in range(n_batches)
        ]
        return torch.cat(chunks, 0)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class AutoencoderKL(AutoencodingEngineLegacy):
    """KL-regularized legacy autoencoder (the classic SD VAE): an
    AutoencodingEngineLegacy fixed to the DiagonalGaussianRegularizer."""

    def __init__(self, **kwargs):
        # Accept the historical "lossconfig" spelling.
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        regularizer = {
            "target": "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
        }
        super().__init__(regularizer_config=regularizer, **kwargs)
|
ComfyUI/comfy/ldm/modules/attention.py
ADDED
|
@@ -0,0 +1,1035 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import nn, einsum
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
from typing import Optional
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from .diffusionmodules.util import AlphaBlender, timestep_embedding
|
| 12 |
+
from .sub_quadratic_attention import efficient_dot_product_attention
|
| 13 |
+
|
| 14 |
+
from comfy import model_management
|
| 15 |
+
|
| 16 |
+
if model_management.xformers_enabled():
|
| 17 |
+
import xformers
|
| 18 |
+
import xformers.ops
|
| 19 |
+
|
| 20 |
+
if model_management.sage_attention_enabled():
|
| 21 |
+
try:
|
| 22 |
+
from sageattention import sageattn
|
| 23 |
+
except ModuleNotFoundError as e:
|
| 24 |
+
if e.name == "sageattention":
|
| 25 |
+
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
|
| 26 |
+
else:
|
| 27 |
+
raise e
|
| 28 |
+
exit(-1)
|
| 29 |
+
|
| 30 |
+
if model_management.flash_attention_enabled():
|
| 31 |
+
try:
|
| 32 |
+
from flash_attn import flash_attn_func
|
| 33 |
+
except ModuleNotFoundError:
|
| 34 |
+
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
| 35 |
+
exit(-1)
|
| 36 |
+
|
| 37 |
+
from comfy.cli_args import args
|
| 38 |
+
import comfy.ops
|
| 39 |
+
ops = comfy.ops.disable_weight_init
|
| 40 |
+
|
| 41 |
+
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
|
| 42 |
+
|
| 43 |
+
def get_attn_precision(attn_precision, current_dtype):
    """Resolve the dtype used for attention-score computation.

    Returns None (no upcast) when the user disabled upcasting, a forced
    dtype when the platform requires upcasting for *current_dtype*,
    otherwise the caller-requested *attn_precision* unchanged.
    """
    if args.dont_upcast_attention:
        return None

    forced = FORCE_UPCAST_ATTENTION_DTYPE
    if forced is not None and current_dtype in forced:
        return forced[current_dtype]
    return attn_precision
|
| 50 |
+
|
| 51 |
+
def exists(val):
    """Return True when *val* is anything other than None."""
    return val is not None
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def default(val, d):
    """Return *val* unless it is None, in which case return fallback *d*."""
    # None-check inlined (equivalent to the exists() helper).
    return val if val is not None else d
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# feedforward
|
| 62 |
+
class GEGLU(nn.Module):
    """GEGLU projection: one linear layer producing value and gate halves,
    combined as value * GELU(gate)."""

    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
        super().__init__()
        # Single projection to 2*dim_out; split into (value, gate) in forward.
        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class FeedForward(nn.Module):
    """Transformer MLP block: project to dim*mult, nonlinearity (GELU or
    GEGLU), dropout, project back to dim_out (defaults to dim)."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = int(dim * mult)
        # None-check inlined (equivalent to the default() helper).
        dim_out = dim if dim_out is None else dim_out

        if glu:
            project_in = GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)
        else:
            project_in = nn.Sequential(
                operations.Linear(dim, inner_dim, dtype=dtype, device=device),
                nn.GELU()
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
        )

    def forward(self, x):
        return self.net(x)
|
| 90 |
+
|
| 91 |
+
def Normalize(in_channels, dtype=None, device=None):
    """Standard normalization used by the attention stack:
    GroupNorm with 32 groups, eps=1e-6 and learnable affine parameters."""
    return torch.nn.GroupNorm(num_channels=in_channels, num_groups=32, affine=True, eps=1e-6, dtype=dtype, device=device)
|
| 93 |
+
|
| 94 |
+
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Reference multi-head attention computed with explicit einsums.

    q/k/v are (b, seq, heads*dim_head), or (b, heads, seq, dim_head) when
    skip_reshape is True. mask may be a boolean keep-mask or an additive
    float mask. Returns (b, seq, heads*dim_head), or (b, heads, seq,
    dim_head) when skip_output_reshape is True.
    """
    attn_precision = get_attn_precision(attn_precision, q.dtype)

    b = q.shape[0]
    if skip_reshape:
        dim_head = q.shape[-1]
    else:
        dim_head = q.shape[-1] // heads

    scale = dim_head ** -0.5

    # Fold heads into the batch dimension: (b*heads, seq, dim_head).
    if skip_reshape:
        q, k, v = [t.reshape(b * heads, -1, dim_head) for t in (q, k, v)]
    else:
        q, k, v = [
            t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous()
            for t in (q, k, v)
        ]

    # force cast to fp32 to avoid overflowing
    if attn_precision == torch.float32:
        scores = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        scores = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if mask is not None:
        if mask.dtype == torch.bool:
            mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
            neg_limit = -torch.finfo(scores.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=heads)
            scores.masked_fill_(~mask, neg_limit)
        else:
            # Additive float mask: broadcast to one slab per batch*head.
            bs = 1 if mask.ndim == 2 else mask.shape[0]
            mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
            scores.add_(mask)

    # attention, what we cannot get enough of
    probs = scores.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', probs.to(v.dtype), v)

    # Un-fold heads; flatten back unless the caller wants the per-head layout.
    out = out.unsqueeze(0).reshape(b, heads, -1, dim_head)
    if not skip_output_reshape:
        out = out.permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
    return out
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Sub-quadratic-memory attention via efficient_dot_product_attention,
    with query/kv chunk sizes picked from currently free device memory.

    query/key/value are (b, seq, heads*dim_head), or (b, heads, seq, dim_head)
    when skip_reshape is True. mask, if given, is an additive float mask.
    Returns (b, seq, heads*dim_head), or the per-head layout when
    skip_output_reshape is True.
    """
    attn_precision = get_attn_precision(attn_precision, query.dtype)

    if skip_reshape:
        b, _, _, dim_head = query.shape
    else:
        b, _, dim_head = query.shape
        dim_head //= heads

    # Fold heads into the batch dim; key is laid out transposed
    # (b*heads, dim_head, k_tokens) as the chunked kernel expects.
    if skip_reshape:
        query = query.reshape(b * heads, -1, dim_head)
        value = value.reshape(b * heads, -1, dim_head)
        key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
    else:
        query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)


    dtype = query.dtype  # remembered so the output can be cast back
    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
    if upcast_attention:
        bytes_per_token = torch.finfo(torch.float32).bits//8
    else:
        bytes_per_token = torch.finfo(query.dtype).bits//8
    batch_x_heads, q_tokens, _ = query.shape
    _, _, k_tokens = key.shape

    mem_free_total, _ = model_management.get_free_memory(query.device, True)

    kv_chunk_size_min = None
    kv_chunk_size = None
    query_chunk_size = None

    # Pick the largest query chunk for which the full K dimension fits in
    # free memory (heuristic: 4 bytes-per-token safety factor).
    for x in [4096, 2048, 1024, 512, 256]:
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens
            query_chunk_size = x
            break

    # Fallback when even the smallest candidate did not fit unchunked KV.
    if query_chunk_size is None:
        query_chunk_size = 512

    if mask is not None:
        # Broadcast the additive mask to one slab per batch*head.
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    hidden_states = efficient_dot_product_attention(
        query,
        key,
        value,
        query_chunk_size=query_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=kv_chunk_size_min,
        use_checkpoint=False,
        upcast_attention=upcast_attention,
        mask=mask,
    )

    hidden_states = hidden_states.to(dtype)
    # Un-fold heads out of the batch dim; flatten unless per-head layout wanted.
    if skip_output_reshape:
        hidden_states = hidden_states.unflatten(0, (-1, heads))
    else:
        hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
    return hidden_states
|
| 232 |
+
|
| 233 |
+
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Attention computed in query-dimension slices to bound peak memory.

    The score matrix is built slice by slice; on an OOM exception the cache
    is emptied once, then the number of slices doubles (up to 64) before
    giving up. q/k/v layouts and return shapes match attention_basic.
    """
    attn_precision = get_attn_precision(attn_precision, q.dtype)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    # Fold heads into the batch dimension: (b*heads, seq, dim_head).
    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    # Output accumulator, filled one query slice at a time.
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    mem_free_total = model_management.get_free_memory(q.device)

    if attn_precision == torch.float32:
        element_size = 4
        upcast = True
    else:
        element_size = q.element_size()
        upcast = False

    gb = 1024 ** 3
    # Size of the full (q_len x k_len) score matrix across batch*heads.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    modifier = 3
    mem_required = tensor_size * modifier
    steps = 1


    if mem_required > mem_free_total:
        # Round the deficit up to the next power of two slice count.
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

    if mask is not None:
        # Broadcast the additive mask to one slab per batch*head.
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
    first_op_done = False
    cleared_cache = False
    while True:
        try:
            # Fall back to a single full-size slice when steps does not divide
            # the query length evenly.
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                if upcast:
                    # Compute scores in fp32 regardless of autocast state.
                    with torch.autocast(enabled=False, device_type = 'cuda'):
                        s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                else:
                    s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

                if mask is not None:
                    # NOTE(review): mask was reshaped to 3-D above, so this
                    # 2-D branch looks unreachable — confirm before removing.
                    if len(mask.shape) == 2:
                        s1 += mask[i:end]
                    else:
                        if mask.shape[1] == 1:
                            s1 += mask
                        else:
                            s1 += mask[:, i:end]

                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                first_op_done = True

                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                del s2
            break
        except model_management.OOM_EXCEPTION as e:
            # Only retry if no slice finished yet (partial results would be
            # inconsistent after freeing memory mid-computation).
            if first_op_done == False:
                model_management.soft_empty_cache(True)
                if cleared_cache == False:
                    cleared_cache = True
                    logging.warning("out of memory error, emptying cache and trying again")
                    continue
                steps *= 2
                if steps > 64:
                    raise e
                logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
            else:
                raise e

    del q, k, v

    # Un-fold heads; flatten back unless the caller wants the per-head layout.
    if skip_output_reshape:
        r1 = (
            r1.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
        )
    else:
        r1 = (
            r1.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, heads * dim_head)
        )
    return r1
|
| 353 |
+
|
| 354 |
+
# Detect xformers versions with the large-batch CUDA bug so that
# attention_xformers can fall back to pytorch attention for big batches.
BROKEN_XFORMERS = False
try:
    x_vers = xformers.__version__
    # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
    BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
    # swallowed. xformers may be entirely absent (NameError when it was never
    # imported) or lack __version__; either way assume it is not broken.
    pass
|
| 361 |
+
|
| 362 |
+
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Attention via xformers.ops.memory_efficient_attention.

    Falls back to attention_pytorch on known-broken xformers versions when
    batch*heads > 65535, and under torch jit tracing/scripting.
    attn_precision is accepted for interface parity but unused here.
    """
    b = q.shape[0]
    dim_head = q.shape[-1]
    # check to make sure xformers isn't broken
    disabled_xformers = False

    if BROKEN_XFORMERS:
        if b * heads > 65535:
            disabled_xformers = True

    if not disabled_xformers:
        if torch.jit.is_tracing() or torch.jit.is_scripting():
            disabled_xformers = True

    if disabled_xformers:
        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)

    if skip_reshape:
        # b h k d -> b k h d
        q, k, v = map(
            lambda t: t.permute(0, 2, 1, 3),
            (q, k, v),
        )
    # actually do the reshaping
    else:
        dim_head //= heads
        q, k, v = map(
            lambda t: t.reshape(b, -1, heads, dim_head),
            (q, k, v),
        )

    if mask is not None:
        # add a singleton batch dimension
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a singleton heads dimension
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)
        # pad to a multiple of 8
        pad = 8 - mask.shape[-1] % 8
        # the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
        # but when using separated heads, the shape has to be (B, H, Nq, Nk)
        # in flux, this matrix ends up being over 1GB
        # here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
        mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)

        mask_out[..., :mask.shape[-1]] = mask
        # doesn't this remove the padding again??
        # NOTE(review): the slice drops the padded columns from the *view*,
        # but the underlying storage keeps the padded last-dim stride, which
        # is presumably the alignment xformers requires — confirm before
        # simplifying this away.
        mask = mask_out[..., :mask.shape[-1]]
        mask = mask.expand(b, heads, -1, -1)

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    # Output comes back as (b, seq, heads, dim_head); permute to the per-head
    # layout or flatten, depending on what the caller asked for.
    if skip_output_reshape:
        out = out.permute(0, 2, 1, 3)
    else:
        out = (
            out.reshape(b, -1, heads * dim_head)
        )

    return out
|
| 423 |
+
|
| 424 |
+
# pytorch 2.3 and up seem to have this issue.
# Cap the batch size handed to scaled_dot_product_attention in one call on
# NVIDIA; elsewhere the limit is effectively unbounded.
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**15 if model_management.is_nvidia() else 2**31
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Attention via torch.nn.functional.scaled_dot_product_attention.

    Args:
        q, k, v: (b, seq, heads*dim_head) tensors, or (b, heads, seq, dim_head)
            when skip_reshape is True.
        heads: number of attention heads.
        mask: optional additive mask, 2-D (Nq, Nk), 3-D (B, Nq, Nk) or
            4-D (B, H, Nq, Nk).
        attn_precision: accepted for interface parity; unused here (SDPA
            manages its own precision).
        skip_reshape: inputs are already split per-head.
        skip_output_reshape: return (b, heads, seq, dim_head) instead of
            flattening back to (b, seq, heads*dim_head).
    """
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
            (q, k, v),
        )

    if mask is not None:
        # add a batch dimension if there isn't already one
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a heads dimension if there isn't already one
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    if SDP_BATCH_LIMIT >= b:
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
        if not skip_output_reshape:
            out = (
                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
            )
    else:
        # Very large batches fail inside SDPA on some platforms; process the
        # batch dimension in SDP_BATCH_LIMIT-sized chunks instead.
        out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
        for i in range(0, b, SDP_BATCH_LIMIT):
            m = mask
            if mask is not None:
                if mask.shape[0] > 1:
                    m = mask[i : i + SDP_BATCH_LIMIT]

            out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
                q[i : i + SDP_BATCH_LIMIT],
                k[i : i + SDP_BATCH_LIMIT],
                v[i : i + SDP_BATCH_LIMIT],
                attn_mask=m,
                dropout_p=0.0, is_causal=False
            ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
        if skip_output_reshape:
            # Bug fix: this chunked path previously ignored skip_output_reshape
            # and always returned the flattened layout; restore the per-head
            # (b, heads, seq, dim_head) layout callers asked for.
            out = out.view(b, -1, heads, dim_head).transpose(1, 2)
    return out
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Multi-head attention using the external sageattn kernel.

    Falls back to attention_pytorch if sageattn raises for any reason.
    """
    if skip_reshape:
        # Inputs already (batch, heads, tokens, dim_head): sage "HND" layout.
        b, _, _, dim_head = q.shape
        tensor_layout = "HND"
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        # (batch, tokens, heads, dim_head): sage "NHD" layout, no transpose needed.
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head),
            (q, k, v),
        )
        tensor_layout = "NHD"

    if mask is not None:
        # add a batch dimension if there isn't already one
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a heads dimension if there isn't already one
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    try:
        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
    except Exception as e:
        logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
        # attention_pytorch (with skip_reshape=True) expects
        # (batch, heads, tokens, dim_head), so NHD tensors must be transposed.
        if tensor_layout == "NHD":
            q, k, v = map(
                lambda t: t.transpose(1, 2),
                (q, k, v),
            )
        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)

    if tensor_layout == "HND":
        # Output is (batch, heads, tokens, dim_head).
        if not skip_output_reshape:
            out = (
                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
            )
    else:
        # Output is (batch, tokens, heads, dim_head).
        if skip_output_reshape:
            out = out.transpose(1, 2)
        else:
            out = out.reshape(b, -1, heads * dim_head)
    return out
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
try:
    # Register flash_attn_func as a torch custom op (with a fake/meta kernel)
    # so it composes with torch.compile. torch.library.custom_op only exists
    # in newer pytorch versions, hence the AttributeError fallback below.
    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)


    @flash_attn_wrapper.register_fake
    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
        # Output shape is the same as q
        return q.new_empty(q.shape)
except AttributeError as error:
    FLASH_ATTN_ERROR = error

    # Stub that fails loudly if flash attention is actually invoked despite
    # registration having failed.
    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
        assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
    """Multi-head attention via flash attention, with SDPA fallback.

    NOTE(review): masks are not supported by the flash path — the assert in
    the try-block forces the SDPA fallback whenever a mask is provided.
    """
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        # (b, tokens, heads*dim_head) -> (b, heads, tokens, dim_head)
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
            (q, k, v),
        )

    if mask is not None:
        # add a batch dimension if there isn't already one
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        # add a heads dimension if there isn't already one
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    try:
        assert mask is None
        # flash_attn expects (batch, tokens, heads, dim_head), so transpose
        # in and back out again.
        out = flash_attn_wrapper(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            dropout_p=0.0,
            causal=False,
        ).transpose(1, 2)
    except Exception as e:
        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    if not skip_output_reshape:
        out = (
            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
        )
    return out
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
# Select the attention implementation once at import time, based on user
# flags and availability reported by model_management.
optimized_attention = attention_basic

if model_management.sage_attention_enabled():
    logging.info("Using sage attention")
    optimized_attention = attention_sage
elif model_management.xformers_enabled():
    logging.info("Using xformers attention")
    optimized_attention = attention_xformers
elif model_management.flash_attention_enabled():
    logging.info("Using Flash Attention")
    optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled():
    logging.info("Using pytorch attention")
    optimized_attention = attention_pytorch
else:
    if args.use_split_cross_attention:
        logging.info("Using split optimization for attention")
        optimized_attention = attention_split
    else:
        logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
        optimized_attention = attention_sub_quad

# Every implementation above accepts a mask argument, so the masked variant
# aliases the same function.
optimized_attention_masked = optimized_attention
|
| 599 |
+
|
| 600 |
+
def optimized_attention_for_device(device, mask=False, small_input=False):
    """Pick the attention function best suited for a device / input size.

    small_input prefers low-overhead implementations; CPU always gets the
    sub-quadratic implementation; otherwise the globally selected backend
    (masked variant when mask=True) is returned.
    """
    if small_input:
        #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
        if model_management.pytorch_attention_enabled():
            return attention_pytorch
        return attention_basic

    if device == torch.device("cpu"):
        return attention_sub_quad

    return optimized_attention_masked if mask else optimized_attention
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
class CrossAttention(nn.Module):
    """Standard cross-attention; acts as self-attention when context is None."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = dim_head * heads
        # Self-attention when no separate context dimension is given.
        context_dim = default(context_dim, query_dim)
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, value=None, mask=None):
        """Attend from x to context (context defaults to x).

        value: optional separate tensor for the value projection; when None
        the values are projected from context.
        """
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        if value is not None:
            v = self.to_v(value)
            del value
        else:
            v = self.to_v(context)

        if mask is None:
            out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        return self.to_out(out)
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
class BasicTransformerBlock(nn.Module):
    """Transformer block: self-attn (attn1) + cross-attn (attn2) + feed-forward,
    with extensive hook ("patch") support via transformer_options."""

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
                 disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()

        # An explicit inner_dim implies an input feed-forward projection.
        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        # Residual connections are only valid when inner and outer dims match.
        self.is_res = inner_dim == dim
        self.attn_precision = attn_precision

        if self.ff_in:
            self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
            self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                    context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

        if disable_temporal_crossattention:
            if switch_temporal_ca_to_sa:
                raise ValueError
            else:
                self.attn2 = None
        else:
            context_dim_attn2 = None
            if not switch_temporal_ca_to_sa:
                context_dim_attn2 = context_dim

            self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
                                        heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
            self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.n_heads = n_heads
        self.d_head = d_head
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

    def forward(self, x, context=None, transformer_options={}):
        # Split transformer_options into the two patch dicts; everything else
        # is forwarded to patch callables via extra_options.
        extra_options = {}
        block = transformer_options.get("block", None)
        block_index = transformer_options.get("block_index", 0)
        transformer_patches = {}
        transformer_patches_replace = {}

        for k in transformer_options:
            if k == "patches":
                transformer_patches = transformer_options[k]
            elif k == "patches_replace":
                transformer_patches_replace = transformer_options[k]
            else:
                extra_options[k] = transformer_options[k]

        extra_options["n_heads"] = self.n_heads
        extra_options["dim_head"] = self.d_head
        extra_options["attn_precision"] = self.attn_precision

        # Optional input feed-forward, residual only when dims match.
        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context
        else:
            context_attn1 = None
        value_attn1 = None

        # "attn1_patch" hooks may rewrite q/k/v inputs before attention runs.
        if "attn1_patch" in transformer_patches:
            patch = transformer_patches["attn1_patch"]
            if context_attn1 is None:
                context_attn1 = n
            value_attn1 = context_attn1
            for p in patch:
                n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

        # Replacement patches may be keyed by (block_type, block_id, index)
        # or just (block_type, block_id).
        if block is not None:
            transformer_block = (block[0], block[1], block_index)
        else:
            transformer_block = None
        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block

        if block_attn1 in attn1_replace_patch:
            # The replacement patch receives already-projected q/k/v and its
            # output goes through the usual output projection.
            if context_attn1 is None:
                context_attn1 = n
                value_attn1 = n
            n = self.attn1.to_q(n)
            context_attn1 = self.attn1.to_k(context_attn1)
            value_attn1 = self.attn1.to_v(value_attn1)
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
        else:
            n = self.attn1(n, context=context_attn1, value=value_attn1)

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x = n + x
        if "middle_patch" in transformer_patches:
            patch = transformer_patches["middle_patch"]
            for p in patch:
                x = p(x, extra_options)

        # Cross-attention (or temporal self-attention), same patch scheme.
        if self.attn2 is not None:
            n = self.norm2(x)
            if self.switch_temporal_ca_to_sa:
                context_attn2 = n
            else:
                context_attn2 = context
            value_attn2 = None
            if "attn2_patch" in transformer_patches:
                patch = transformer_patches["attn2_patch"]
                value_attn2 = context_attn2
                for p in patch:
                    n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

            attn2_replace_patch = transformer_patches_replace.get("attn2", {})
            block_attn2 = transformer_block
            if block_attn2 not in attn2_replace_patch:
                block_attn2 = block

            if block_attn2 in attn2_replace_patch:
                if value_attn2 is None:
                    value_attn2 = context_attn2
                n = self.attn2.to_q(n)
                context_attn2 = self.attn2.to_k(context_attn2)
                value_attn2 = self.attn2.to_v(value_attn2)
                n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
                n = self.attn2.to_out(n)
            else:
                n = self.attn2(n, context=context_attn2, value=value_attn2)

        if "attn2_output_patch" in transformer_patches:
            patch = transformer_patches["attn2_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x = n + x
        # Final feed-forward with residual when dims match.
        if self.is_res:
            x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x = x_skip + x

        return x
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()
        # A single context_dim is shared across all depth blocks.
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
        if not use_linear:
            self.proj_in = operations.Conv2d(in_channels,
                                             inner_dim,
                                             kernel_size=1,
                                             stride=1,
                                             padding=0, dtype=dtype, device=device)
        else:
            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
                for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = operations.Conv2d(inner_dim,in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0, dtype=dtype, device=device)
        else:
            # NOTE(review): argument order here is (in_channels, inner_dim),
            # the reverse of proj_in; in practice the two are equal so the
            # shape works out — confirm against the reference weights before
            # changing.
            self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options={}):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        transformer_options["activations_shape"] = list(x.shape)
        x_in = x
        x = self.norm(x)
        # Conv projection happens in image layout; linear projection after
        # flattening to token layout.
        if not self.use_linear:
            x = self.proj_in(x)
        # (b, c, h, w) -> (b, h*w, c)
        x = x.movedim(1, 3).flatten(1, 2).contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        # (b, h*w, c) -> (b, c, h, w)
        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        # Residual connection around the whole transformer.
        return x + x_in
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
class SpatialVideoTransformer(SpatialTransformer):
    """SpatialTransformer extended with a temporal attention stack.

    Each spatial block is paired with a temporal ("time mix") block that
    attends across frames; spatial and temporal outputs are blended by an
    AlphaBlender.
    """
    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        max_time_embed_period: int = 10000,
        attn_precision=None,
        dtype=None, device=None, operations=ops
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
            attn_precision=attn_precision,
            dtype=dtype, device=device, operations=operations
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        # Temporal blocks mirror the spatial head configuration.
        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            # Temporal cross-attention reuses the spatial conditioning.
            time_context_dim = context_dim

        self.time_stack = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    # timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                    attn_precision=attn_precision,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(self.depth)
            ]
        )

        # One temporal block per spatial block.
        assert len(self.time_stack) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        # Sinusoidal frame-position embedding projected through a small MLP.
        time_embed_dim = self.in_channels * 4
        self.time_pos_embed = nn.Sequential(
            operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
        )

        self.time_mixer = AlphaBlender(
            alpha=merge_factor, merge_strategy=merge_strategy
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        transformer_options={}
    ) -> torch.Tensor:
        _, _, h, w = x.shape
        transformer_options["activations_shape"] = list(x.shape)
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        if self.use_spatial_context:
            assert (
                context.ndim == 3
            ), f"n dims of spatial context should be 3 but are {context.ndim}"

            if time_context is None:
                time_context = context
            # Take the context of each video's first frame and broadcast it
            # over all h*w spatial positions for the temporal blocks.
            time_context_first_timestep = time_context[::timesteps]
            time_context = repeat(
                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
            )
        elif time_context is not None and not self.use_spatial_context:
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        # Per-frame positional embedding, added to tokens before each
        # temporal block.
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
        emb = self.time_pos_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(
            zip(self.transformer_blocks, self.time_stack)
        ):
            transformer_options["block_index"] = it_
            x = block(
                x,
                context=spatial_context,
                transformer_options=transformer_options,
            )

            x_mix = x
            x_mix = x_mix + emb

            # Fold spatial positions into the batch so attention runs across
            # the time dimension, then fold back.
            B, S, C = x_mix.shape
            x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
            x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
            x_mix = rearrange(
                x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
            )

            # Blend the spatial and temporal branches.
            x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out
|
| 1034 |
+
|
| 1035 |
+
|
ComfyUI/comfy/ldm/modules/ema.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    Shadow copies of every parameter with requires_grad are registered as
    buffers (dots stripped from the names, since '.' is not allowed in buffer
    names). Calling the module moves the shadow copies toward the model's
    current parameters.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        # NOTE: the misspelled "use_num_upates" kwarg is kept for backward
        # compatibility with existing callers.
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}  # parameter name -> shadow buffer name
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates < 0 disables the warm-up decay schedule in forward().
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        """Restart the warm-up decay schedule from zero updates."""
        del self.num_updates
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        """Update the shadow parameters toward `model`'s current parameters."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up: effective decay ramps from ~0.1 toward self.decay so
            # early shadow values track the model more closely.
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # shadow = decay * shadow + (1 - decay) * param
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        """Copy the shadow (EMA) parameters into `model`, in place."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
|
ComfyUI/comfy/ldm/modules/sub_quadratic_attention.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# original source:
|
| 2 |
+
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
|
| 3 |
+
# license:
|
| 4 |
+
# MIT
|
| 5 |
+
# credit:
|
| 6 |
+
# Amin Rezaei (original author)
|
| 7 |
+
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
|
| 8 |
+
# implementation of:
|
| 9 |
+
# Self-attention Does Not Need O(n2) Memory":
|
| 10 |
+
# https://arxiv.org/abs/2112.05682v2
|
| 11 |
+
|
| 12 |
+
from functools import partial
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch.utils.checkpoint import checkpoint
|
| 16 |
+
import math
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from typing import Optional, NamedTuple, List, Protocol
|
| 21 |
+
except ImportError:
|
| 22 |
+
from typing import Optional, NamedTuple, List
|
| 23 |
+
from typing_extensions import Protocol
|
| 24 |
+
|
| 25 |
+
from typing import List
|
| 26 |
+
|
| 27 |
+
from comfy import model_management
|
| 28 |
+
|
| 29 |
+
def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    """Return the view x[starts[0]:starts[0]+sizes[0], ...] over every dim."""
    region = [slice(begin, begin + extent) for begin, extent in zip(starts, sizes)]
    return x[tuple(region)]
|
| 36 |
+
|
| 37 |
+
class AttnChunk(NamedTuple):
    """Partial attention results for one key/value chunk (log-sum-exp pieces)."""
    exp_values: Tensor       # exp(scores - max) @ value for this chunk
    exp_weights_sum: Tensor  # sum of exp(scores - max) over the chunk's keys
    max_score: Tensor        # per-query max score within the chunk
|
| 41 |
+
|
| 42 |
+
class SummarizeChunk(Protocol):
    """Callable that computes partial attention stats for one k/v chunk."""
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...
|
| 49 |
+
|
| 50 |
+
class ComputeQueryChunkAttn(Protocol):
    """Callable that computes full attention output for one query chunk."""
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> Tensor: ...
|
| 57 |
+
|
| 58 |
+
def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> AttnChunk:
    """Compute one key/value chunk's contribution to attention for `query`.

    Returns numerically-stabilized pieces (exp-weighted values, exp-weight
    sums, and the per-query max score) that _query_chunk_attention later
    combines across chunks into the exact softmax result.
    """
    if upcast_attention:
        # Run the score matmul in fp32 for numerical stability; autocast is
        # disabled so the .float() casts are not undone by the context.
        with torch.autocast(enabled=False, device_type = 'cuda'):
            query = query.float()
            key_t = key_t.float()
            attn_weights = torch.baddbmm(
                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
                query,
                key_t,
                alpha=scale,
                beta=0,
            )
    else:
        attn_weights = torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )
    # Subtract the (detached) per-row max before exponentiating — the
    # classic log-sum-exp stabilization.
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    attn_weights -= max_score
    if mask is not None:
        attn_weights += mask
    torch.exp(attn_weights, out=attn_weights)
    exp_weights = attn_weights.to(value.dtype)
    exp_values = torch.bmm(exp_weights, value)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
| 95 |
+
|
| 96 |
+
def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
    mask,
) -> Tensor:
    """Attend one query chunk over all keys/values, kv_chunk_size tokens at a time.

    Per-chunk partial results (AttnChunk) are rescaled to a common global max
    score and summed, yielding the exact softmax-attention output without
    materializing the full score matrix.

    Args:
        query: [batch*heads, q_chunk_tokens, channels] query chunk.
        key_t: [batch*heads, channels, k_tokens] transposed keys.
        value: [batch*heads, k_tokens, channels] values.
        summarize_chunk: callback producing an AttnChunk for one kv slice.
        kv_chunk_size: number of key/value tokens per slice.
        mask: optional additive mask sliced along the key axis per chunk, or None.
    """
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
        # Slice out one kv chunk; Python slicing clamps past the end, so the
        # final chunk may be shorter than kv_chunk_size.
        key_chunk = dynamic_slice(
            key_t,
            (0, 0, chunk_idx),
            (batch_x_heads, k_channels_per_head, kv_chunk_size)
        )
        value_chunk = dynamic_slice(
            value,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, v_channels_per_head)
        )
        if mask is not None:
            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]

        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

    chunks: List[AttnChunk] = [
        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
    ]
    # Stack each AttnChunk field across chunks: [n_chunks, ...] per component.
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    # Rescale every chunk's contributions to the global max score.
    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights
|
| 137 |
+
|
| 138 |
+
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
| 139 |
+
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> Tensor:
    """Plain (non-kv-chunked) attention for one query chunk.

    Used as the fast path when every key token fits in a single kv chunk.
    If the regular softmax runs out of memory, falls back to a slower
    in-place softmax over the same score tensor.

    Args:
        query: [batch*heads, q_tokens, channels] queries.
        key_t: [batch*heads, channels, k_tokens] transposed keys.
        value: [batch*heads, k_tokens, channels] values.
        scale: score scaling factor.
        upcast_attention: run the score matmul in float32 with autocast disabled.
        mask: optional additive attention mask, or None.
    """
    if upcast_attention:
        with torch.autocast(enabled=False, device_type = 'cuda'):
            query = query.float()
            key_t = key_t.float()
            # baddbmm with beta=0 ignores the empty input tensor: alpha * (q @ k_t).
            attn_scores = torch.baddbmm(
                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
                query,
                key_t,
                alpha=scale,
                beta=0,
            )
    else:
        attn_scores = torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )

    if mask is not None:
        attn_scores += mask
    try:
        attn_probs = attn_scores.softmax(dim=-1)
        del attn_scores
    except model_management.OOM_EXCEPTION:
        logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
        # In-place, numerically stabilized softmax: shift by max, exp, normalize.
        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
        torch.exp(attn_scores, out=attn_scores)
        summed = torch.sum(attn_scores, dim=-1, keepdim=True)
        attn_scores /= summed
        attn_probs = attn_scores

    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
    return hidden_states_slice
|
| 182 |
+
|
| 183 |
+
class ScannedChunk(NamedTuple):
    """Pairs a chunk index with its partial attention result.

    NOTE(review): not referenced anywhere in the code visible here —
    confirm it is still needed before removing.
    """
    chunk_idx: int
    attn_chunk: AttnChunk
|
| 186 |
+
|
| 187 |
+
def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
    upcast_attention=False,
    mask = None,
):
    """Computes efficient dot-product attention given query, transposed key, and value.
    This is efficient version of attention presented in
    https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
    Args:
        query: queries for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        key_t: keys for calculating attention with shape of
            `[batch * num_heads, channels_per_head, tokens]`.
        value: values to be used in attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
        upcast_attention: bool: run the attention score matmuls in float32.
        mask: optional additive attention mask; a 2D mask gains a broadcast batch dim.
    Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    if mask is not None and len(mask.shape) == 2:
        # add a broadcast batch dimension
        mask = mask.unsqueeze(0)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
            query,
            (0, chunk_idx, 0),
            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
        )

    def get_mask_chunk(chunk_idx: int) -> Tensor:
        if mask is None:
            return None
        if mask.shape[1] == 1:
            # mask broadcasts over all query tokens; no slicing needed
            return mask
        chunk = min(query_chunk_size, q_tokens)
        return mask[:,chunk_idx:chunk_idx + chunk]

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
    # Optionally wrap the per-chunk summary in gradient checkpointing
    # (recompute in backward instead of storing activations).
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
        _get_attention_scores_no_kv_chunking,
        scale=scale,
        upcast_attention=upcast_attention
    ) if k_tokens <= kv_chunk_size else (
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )
    )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
            mask=mask,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key_t=key_t,
            value=value,
            mask=get_mask_chunk(i * query_chunk_size)
        ) for i in range(math.ceil(q_tokens / query_chunk_size))
    ], dim=1)
    return res
|
ComfyUI/comfy/ldm/modules/temporal_ae.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from typing import Iterable, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from einops import rearrange, repeat
|
| 6 |
+
|
| 7 |
+
import comfy.ops
|
| 8 |
+
ops = comfy.ops.disable_weight_init
|
| 9 |
+
|
| 10 |
+
from .diffusionmodules.model import (
|
| 11 |
+
AttnBlock,
|
| 12 |
+
Decoder,
|
| 13 |
+
ResnetBlock,
|
| 14 |
+
)
|
| 15 |
+
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
|
| 16 |
+
from .attention import BasicTransformerBlock
|
| 17 |
+
|
| 18 |
+
def partialclass(cls, *args, **kwargs):
    """Return a subclass of ``cls`` whose ``__init__`` has ``args``/``kwargs`` pre-bound."""
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class VideoResBlock(ResnetBlock):
    """Spatial ResnetBlock extended with a temporal 3D residual block.

    The spatial output is blended with the temporally-processed output using
    a mix factor that is either fixed (buffer) or learned (sigmoid of a
    trainable parameter), selected by ``merge_strategy``.
    """

    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
        if video_kernel_size is None:
            video_kernel_size = [3, 1, 1]
        # Temporal residual block (dims=3); timestep embedding is skipped.
        self.time_stack = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=False,
            skip_t_emb=True,
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            # constant blend weight
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            # trainable blend weight (squashed by sigmoid in get_alpha)
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, bs):
        # Returns the temporal/spatial blend factor; `bs` is unused here.
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()

    def forward(self, x, temb, skip_video=False, timesteps=None):
        b, c, h, w = x.shape
        if timesteps is None:
            # assumes the batch is a single clip, i.e. every sample is one frame
            timesteps = b

        x = super().forward(x, temb)

        if not skip_video:
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = self.time_stack(x, temb)

            alpha = self.get_alpha(bs=b // timesteps).to(x.device)
            # blend temporal branch (x) with the spatial result (x_mix)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class AE3DConv(ops.Conv2d):
    """2D convolution followed by a temporal 3D convolution across frames."""

    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        # "same"-style padding for the temporal conv; per-dimension when a
        # kernel size tuple/list is given.
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = ops.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )

    def forward(self, input, timesteps=None, skip_video=False):
        if timesteps is None:
            # assumes the whole batch is one clip of `timesteps` frames
            timesteps = input.shape[0]
        x = super().forward(input)
        if skip_video:
            return x
        # fold frames into a temporal axis, convolve over time, unfold back
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class AttnVideoBlock(AttnBlock):
    """Spatial AttnBlock extended with a temporal transformer block.

    Sinusoidal frame-index embeddings are added before temporal mixing; the
    temporal and spatial paths are blended by a fixed or learned mix factor.
    """

    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # no context, single headed, as in base class
        self.time_mix_block = BasicTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
        )

        time_embed_dim = self.in_channels * 4
        # MLP turning sinusoidal frame-index embeddings into per-frame offsets
        self.video_time_embed = torch.nn.Sequential(
            ops.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            ops.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            # constant blend weight
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            # trainable blend weight (squashed by sigmoid in get_alpha)
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def forward(self, x, timesteps=None, skip_time_block=False):
        if skip_time_block:
            # behave exactly like the purely-spatial base block
            return super().forward(x)

        if timesteps is None:
            # assumes the whole batch is one clip of `timesteps` frames
            timesteps = x.shape[0]

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        x_mix = x
        # per-frame sinusoidal index embedding, repeated across the batch
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)  # b, n_channels
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        alpha = self.get_alpha().to(x.device)
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x

    def get_alpha(
        self,
    ):
        # blend factor between temporal and spatial attention outputs
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
    conv_op=ops.Conv2d,
):
    """Return an AttnVideoBlock subclass with these construction args pre-bound.

    NOTE(review): ``attn_type``, ``attn_kwargs`` and ``conv_op`` are accepted
    but not used here — presumably kept for interface compatibility with
    other attention factories; confirm callers do not rely on them.
    """
    return partialclass(
        AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
    )
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class Conv2DWrapper(torch.nn.Conv2d):
    """A plain Conv2d whose forward silently accepts extra keyword arguments.

    Lets a standard 2D convolution stand in where callers pass video-specific
    kwargs (e.g. timesteps) intended for temporal variants.
    """

    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
        # Extra kwargs are deliberately discarded; delegate to the stock conv.
        return torch.nn.Conv2d.forward(self, input)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class VideoDecoder(Decoder):
    """Decoder whose conv/attention/resnet ops are swapped for temporal variants.

    ``time_mode`` selects which op families become video-aware.
    """

    # NOTE(review): "only-last-conv" is referenced in __init__ below but is not
    # listed here, so the assert would reject it — verify intended support.
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = 3,
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"

        # Swap the base Decoder's op factories depending on the time mode.
        if self.time_mode != "attn-only":
            kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
        if self.time_mode not in ["conv-only", "only-last-conv"]:
            # NOTE(review): partialclass subclasses its first argument, but
            # make_time_attn is a function — this branch looks like it would
            # fail if ever reached; confirm intended behavior.
            kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
        if self.time_mode not in ["attn-only", "only-last-conv"]:
            kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)

        super().__init__(*args, **kwargs)

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        # Weight of the final conv; the temporal mixing conv's weight unless
        # skip_time_mix is requested.
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )
|
ComfyUI/comfy/ldm/omnigen/omnigen2.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Original code: https://github.com/VectorSpaceLab/OmniGen2
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange, repeat
|
| 9 |
+
from comfy.ldm.lightricks.model import Timesteps
|
| 10 |
+
from comfy.ldm.flux.layers import EmbedND
|
| 11 |
+
from comfy.ldm.modules.attention import optimized_attention_masked
|
| 12 |
+
import comfy.model_management
|
| 13 |
+
import comfy.ldm.common_dit
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def apply_rotary_emb(x, freqs_cis):
    """Apply rotary position embedding factors to ``x``.

    The last dimension of ``x`` is viewed as consecutive component pairs and
    each pair is combined with the two rotation factors stored in the last
    axis of ``freqs_cis``. Empty sequences are returned unchanged.
    """
    if x.shape[1] == 0:
        return x

    pairs = x.reshape(*x.shape[:-1], -1, 1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = freqs_cis[..., 0] * first + freqs_cis[..., 1] * second
    return rotated.reshape(*x.shape).to(dtype=x.dtype)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """SwiGLU gating: the SiLU-activated ``x`` scales ``y`` elementwise."""
    gate = F.silu(x)
    return gate * y
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) lifting timestep features.

    ``operations`` supplies the Linear implementation so weight handling can
    be customized by the caller.
    """

    def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
        self.act = nn.SiLU()
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.linear_1(sample))
        return self.linear_2(hidden)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class LuminaRMSNormZero(nn.Module):
    """AdaLN-style RMSNorm: a modulation embedding yields scale/gate terms.

    Returns the scale-modulated normalized input together with
    (gate_msa, scale_mlp, gate_mlp) for the caller to apply around its
    attention and MLP sublayers.
    """

    def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        # conditioning input is capped at 1024 dims (see the timestep embedder)
        self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device)
        self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        # broadcast the per-sample scale over the token dimension
        x = self.norm(x) * (1 + scale_msa[:, None])
        return x, gate_msa, scale_mlp, gate_mlp
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class LuminaLayerNormContinuous(nn.Module):
    """LayerNorm scaled by a conditioning embedding, with optional output projection."""

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
        self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
        # optional projection applied after modulation
        self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device) if out_dim is not None else None

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        # scale normalized activations by (1 + emb), broadcast over tokens
        x = self.norm(x) * (1 + emb)[:, None, :]
        if self.linear_2 is not None:
            x = self.linear_2(x)
        return x
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LuminaFeedForward(nn.Module):
    """SwiGLU feed-forward: linear_2(silu(linear_1(x)) * linear_3(x)).

    The hidden width is rounded up to the nearest multiple of ``multiple_of``.
    """

    def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
        super().__init__()
        # round the hidden size up to a multiple of `multiple_of`
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
        self.linear_1 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.linear_2 = operations.Linear(inner_dim, dim, bias=False, dtype=dtype, device=device)
        self.linear_3 = operations.Linear(dim, inner_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.linear_1(x)
        scaled = self.linear_3(x)
        return self.linear_2(swiglu(gate, scaled))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    """Produces the timestep embedding and the projected caption features."""

    def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
        super().__init__()
        # sinusoidal projection of the scalar timestep
        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale)
        # timestep MLP output capped at 1024 dims (matches LuminaRMSNormZero's input)
        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024), dtype=dtype, device=device, operations=operations)
        # RMSNorm + Linear projecting text features into the model width
        self.caption_embedder = nn.Sequential(
            operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device),
            operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Attention(nn.Module):
    """Multi-head attention with QK RMSNorm, grouped KV heads, and optional RoPE."""

    def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = heads
        self.kv_heads = kv_heads  # may be < heads (grouped-query attention)
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        self.to_q = operations.Linear(query_dim, heads * dim_head, bias=bias, dtype=dtype, device=device)
        self.to_k = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)
        self.to_v = operations.Linear(query_dim, kv_heads * dim_head, bias=bias, dtype=dtype, device=device)

        # per-head RMSNorm applied to queries and keys before RoPE
        self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
            operations.Linear(heads * dim_head, query_dim, bias=bias, dtype=dtype, device=device),
            nn.Dropout(0.0)
        )

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, sequence_length, _ = hidden_states.shape

        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        # [batch, tokens, heads, dim_head]
        query = query.view(batch_size, -1, self.heads, self.dim_head)
        key = key.view(batch_size, -1, self.kv_heads, self.dim_head)
        value = value.view(batch_size, -1, self.kv_heads, self.dim_head)

        query = self.norm_q(query)
        key = self.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # -> [batch, heads, tokens, dim_head]
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        if self.kv_heads < self.heads:
            # expand grouped KV heads so each query head gets its own K/V copy
            key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
            value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)

        hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
        # to_out[1] is Dropout(0.0); only the linear projection is applied here
        hidden_states = self.to_out[0](hidden_states)
        return hidden_states
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class OmniGen2TransformerBlock(nn.Module):
    """Transformer block: self-attention + SwiGLU MLP, optionally AdaLN-modulated.

    With ``modulation=True``, norm1 emits per-sample scales/gates applied
    around the attention and MLP sublayers (Lumina-style sandwich norm);
    otherwise plain RMSNorm residual updates are used.
    """

    def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
        super().__init__()
        self.modulation = modulation

        self.attn = Attention(
            query_dim=dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            dtype=dtype, device=device, operations=operations,
        )

        # NOTE(review): ffn_dim_multiplier is accepted but not used in this
        # visible code — confirm whether it should influence inner_dim.
        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            dtype=dtype, device=device, operations=operations
        )

        if modulation:
            self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
        else:
            self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

        self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
        self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.modulation:
            # AdaLN path: temb supplies the scale and (tanh-squashed) gate terms
            norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
            hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
            hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
        else:
            norm_hidden_states = self.norm1(hidden_states)
            attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
            hidden_states = hidden_states + self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
            hidden_states = hidden_states + self.ffn_norm2(mlp_output)
        return hidden_states
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class OmniGen2RotaryPosEmbed(nn.Module):
    """Builds 3-axis rotary position embeddings for the combined
    caption + reference-image + noise-image token sequence.

    Axis 0 is a 1-D "sequence" coordinate shared by caption tokens and used
    as a shift for image tokens; axes 1/2 are the patch row/column indices.
    """

    def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size
        self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)

    def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
        """Return per-segment freqs (caption / ref images / noise image), the
        full-sequence freqs, caption lengths and total sequence lengths.

        Sizes are in latent pixels; token grids are obtained by dividing by
        the patch size ``p``.
        """
        p = self.patch_size

        # Total token count per batch item: caption + all ref images + image.
        seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # Caption tokens: identical 1-D position on all three axes.
            position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")

            # pe_shift: running axis-0 offset; pe_shift_len: running token offset.
            pe_shift = cap_seq_len
            pe_shift_len = cap_seq_len

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p

                    # Row/col ids flattened in row-major order to match the
                    # patchified token layout.
                    row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
                    col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids

                    # Advance axis-0 by the larger grid dimension so successive
                    # images do not overlap in the rotary space.
                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            # Target (noise) image occupies the remainder of the sequence.
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p

            row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
            col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()

            position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len: seq_len, 1] = row_ids
            position_ids[i, pe_shift_len: seq_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)

        # Split the full-sequence freqs into per-segment padded tensors.
        cap_freqs_cis_shape = list(freqs_cis.shape)
        cap_freqs_cis_shape[1] = encoder_seq_len
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        ref_img_freqs_cis_shape = list(freqs_cis.shape)
        ref_img_freqs_cis_shape[1] = max_ref_img_len
        ref_img_freqs_cis = torch.zeros(*ref_img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]

        return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class OmniGen2Transformer2DModel(nn.Module):
    """OmniGen2 diffusion transformer.

    Processes a concatenated sequence of text tokens, optional reference-image
    tokens and noise-image tokens through refiner blocks and a main stack of
    modulated transformer layers, then unpatchifies the image portion.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        text_feat_dim: int = 1024,
        timestep_scale: float = 1.0,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()

        self.patch_size = patch_size
        self.out_channels = out_channels or in_channels
        self.hidden_size = hidden_size
        self.dtype = dtype

        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )

        # Separate patch embedders for the noise latent and reference latents.
        self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
        self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=text_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
        )

        self.noise_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_refiner_layers)
        ])

        self.ref_image_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_refiner_layers)
        ])

        # Context refiner runs without timestep modulation.
        self.context_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_refiner_layers)
        ])

        self.layers = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size, num_attention_heads, num_kv_heads,
                multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations
            ) for _ in range(num_layers)
        ])

        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
        )

        # One learned embedding per reference-image slot (supports up to 5).
        self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))

    def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
        """Patchify the noise latent and any reference latents into token
        sequences; returns flattened tensors plus per-item length/size
        bookkeeping used by the rope embedder.
        """
        batch_size = len(hidden_states)
        p = self.patch_size

        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]

        if ref_image_hidden_states is not None:
            ref_image_hidden_states = list(map(lambda ref: comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p)), ref_image_hidden_states))
            # Reference sizes are shared across the batch (same refs per item).
            ref_img_sizes = [[(imgs.size(2), imgs.size(3)) if imgs is not None else None for imgs in ref_image_hidden_states]] * batch_size
            l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
        else:
            ref_img_sizes = [None for _ in range(batch_size)]
            l_effective_ref_img_len = [[0] for _ in range(batch_size)]

        flat_ref_img_hidden_states = None
        if ref_image_hidden_states is not None:
            imgs = []
            for ref_img in ref_image_hidden_states:
                B, C, H, W = ref_img.size()
                ref_img = rearrange(ref_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
                imgs.append(ref_img)
            # All reference images concatenated along the token dimension.
            flat_ref_img_hidden_states = torch.cat(imgs, dim=1)

        img = hidden_states
        B, C, H, W = img.size()
        flat_hidden_states = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        return (
            flat_hidden_states, flat_ref_img_hidden_states,
            None, None,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        )

    def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
        """Embed noise/ref tokens, add per-reference index embeddings, run the
        refiner stacks and return refs + noise concatenated on the token axis.
        """
        batch_size = len(hidden_states)

        hidden_states = self.x_embedder(hidden_states)
        if ref_image_hidden_states is not None:
            ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
            image_index_embedding = comfy.model_management.cast_to(self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device)

            for i in range(batch_size):
                shift = 0
                # Tag each reference image's tokens with its slot embedding.
                for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
                    ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + image_index_embedding[j]
                    shift += ref_img_len

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)

        if ref_image_hidden_states is not None:
            for layer in self.ref_image_refiner:
                ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)

            hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)

        return hidden_states

    def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
        """Denoising forward pass.

        x: (B, C, H, W) noise latent; context: text hidden states;
        num_tokens: effective caption length; ref_latents: optional list of
        reference latents. Returns the negated model output (sign convention
        of this sampler integration).
        """
        B, C, H, W = x.shape
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        _, _, H_padded, W_padded = hidden_states.shape
        # Model was trained with reversed time direction.
        timestep = 1.0 - timesteps
        text_hidden_states = context
        text_attention_mask = attention_mask
        ref_image_hidden_states = ref_latents
        device = hidden_states.device

        temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)

        (
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

        (
            context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
            rotary_emb, encoder_seq_lengths, seq_lengths,
        ) = self.rope_embedder(
            hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes, device,
        )

        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)

        img_len = hidden_states.shape[1]
        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            noise_rotary_emb, ref_img_rotary_emb,
            l_effective_ref_img_len, l_effective_img_len,
            temb,
        )

        # Full sequence: [text | ref images | noise image].
        hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
        attention_mask = None

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)

        hidden_states = self.norm_out(hidden_states, temb)

        # Unpatchify only the trailing noise-image tokens; crop padding.
        p = self.patch_size
        output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded// p, p1=p, p2=p)[:, :, :H, :W]

        return -output
|
ComfyUI/comfy/ldm/pixart/pixartms.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Based on:
|
| 2 |
+
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
| 3 |
+
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from .blocks import (
|
| 8 |
+
t2i_modulate,
|
| 9 |
+
CaptionEmbedder,
|
| 10 |
+
AttentionKVCompress,
|
| 11 |
+
MultiHeadCrossAttention,
|
| 12 |
+
T2IFinalLayer,
|
| 13 |
+
SizeEmbedder,
|
| 14 |
+
)
|
| 15 |
+
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
    """Build a 2-D sin/cos positional embedding of shape (h*w, embed_dim).

    Coordinates are rescaled so a grid of size ``base_size`` spans the
    nominal range, then divided by ``pe_interpolation`` to support
    resolution extrapolation. Half the channels encode the row coordinate,
    half the column coordinate.
    """
    row_coords = torch.arange(h, device=device, dtype=dtype) / (h / base_size) / pe_interpolation
    col_coords = torch.arange(w, device=device, dtype=dtype) / (w / base_size) / pe_interpolation
    grid_h, grid_w = torch.meshgrid(row_coords, col_coords, indexing='ij')

    half_dim = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(half_dim, grid_h, device=device, dtype=dtype)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(half_dim, grid_w, device=device, dtype=dtype)

    # Column embedding first, then row — matches the original layout. (H*W, D)
    return torch.cat([emb_w, emb_h], dim=1)
|
| 28 |
+
|
| 29 |
+
class PixArtMSBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Self-attention (optionally KV-compressed) -> cross-attention to the
    caption tokens -> MLP, each added residually; the self-attention and MLP
    branches are shifted/scaled/gated by the timestep conditioning ``t``.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
                 sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attn = AttentionKVCompress(
            hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
            qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
        )
        self.cross_attn = MultiHeadCrossAttention(
            hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
        )
        self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        # to be compatible with lower version pytorch
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
            dtype=dtype, device=device, operations=operations
        )
        # Learned per-block base values for the 6 adaLN parameters; the
        # timestep projection t is added on top in forward().
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)

    def forward(self, x, y, t, mask=None, HW=None, **kwargs):
        """x: (B, N, C) image tokens; y: caption tokens; t: (B, 6*C) adaLN
        conditioning; mask: optional caption lengths; HW: latent (H, W) for
        KV compression."""
        B, N, C = x.shape

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
        x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
        # Cross-attention branch is not modulated by t.
        x = x + self.cross_attn(x, y, mask)
        x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
### Core PixArt Model ###
|
| 66 |
+
class PixArtMS(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    PixArt multi-scale variant: patch-embeds the latent, adds on-the-fly 2-D
    sin/cos position embeddings, conditions on timestep plus optional
    size/aspect-ratio micro-conditioning, cross-attends to caption tokens,
    and unpatchifies back to the latent resolution.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        pred_sigma=True,
        drop_path: float = 0.,
        caption_channels=4096,
        pe_interpolation=None,
        pe_precision=None,
        config=None,
        model_max_length=120,
        micro_condition=True,
        qk_norm=False,
        kv_compress_config=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        nn.Module.__init__(self)
        self.dtype = dtype
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        # When predicting sigma the model emits [eps, sigma] stacked on C.
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pe_interpolation = pe_interpolation
        self.pe_precision = pe_precision
        self.hidden_size = hidden_size
        self.depth = depth

        approx_gelu = lambda: nn.GELU(approximate="tanh")
        # Projects the timestep embedding to the 6 adaLN params per block.
        self.t_block = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
        )
        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size,
            bias=True,
            dtype=dtype,
            device=device,
            operations=operations
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations,
        )
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
            act_layer=approx_gelu, token_num=model_max_length,
            dtype=dtype, device=device, operations=operations,
        )

        self.micro_conditioning = micro_condition
        if self.micro_conditioning:
            # Size and aspect-ratio embedders each take a third of hidden_size.
            self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
            self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)

        # For fixed sin-cos embedding:
        # num_patches = (input_size // patch_size) * (input_size // patch_size)
        # self.base_size = input_size // self.patch_size
        # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        if kv_compress_config is None:
            kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList([
            PixArtMSBlock(
                hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                sampling=kv_compress_config['sampling'],
                sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
                qk_norm=qk_norm,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(
            hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
        )

    def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
        """
        Original forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) conditioning
        ar: (N, 1): aspect ratio
        cs: (N ,2) size conditioning for height/width
        """
        B, C, H, W = x.shape
        c_res = (H + W) // 2
        pe_interpolation = self.pe_interpolation
        if pe_interpolation is None or self.pe_precision is not None:
            # calculate pe_interpolation on-the-fly
            pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)

        pos_embed = get_2d_sincos_pos_embed_torch(
            self.hidden_size,
            h=(H // self.patch_size),
            w=(W // self.patch_size),
            pe_interpolation=pe_interpolation,
            base_size=((round(c_res / 64) * 64) // self.patch_size),
            device=x.device,
            dtype=x.dtype,
        ).unsqueeze(0)

        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep, x.dtype)  # (N, D)

        if self.micro_conditioning and (c_size is not None and c_ar is not None):
            bs = x.shape[0]
            c_size = self.csize_embedder(c_size, bs)  # (N, D)
            c_ar = self.ar_embedder(c_ar, bs)  # (N, D)
            t = t + torch.cat([c_size, c_ar], dim=1)

        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, D)

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            # Pack only the valid caption tokens of the whole batch into one
            # row; y_lens tells cross-attention where each item's tokens end.
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = None
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        for block in self.blocks:
            x = block(x, y, t0, y_lens, (H, W), **kwargs)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x, H, W)  # (N, out_channels, H, W)

        return x

    def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
        """ComfyUI entry point: fills in default micro-conditioning and
        returns only the epsilon channels."""
        B, C, H, W = x.shape

        # Fallback for missing microconds
        if self.micro_conditioning:
            if c_size is None:
                # NOTE(review): latent->pixel scale assumed to be 8x — confirm.
                c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)

            if c_ar is None:
                c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)

        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)

        ## run original forward pass
        out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)

        ## only return EPS
        if self.pred_sigma:
            return out[:, :self.in_channels]
        return out

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = h // self.patch_size
        w = w // self.patch_size
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
|