Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet_union.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/stable_audio_transformer.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/t5_film_transformer.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_2d.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_allegro.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_bria.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_chroma.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cogview3plus.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cogview4.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cosmos.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_easyanimate.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_flux.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hidream_image.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hunyuan_video.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hunyuan_video_framepack.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_ltx.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_lumina2.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_mochi.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_omnigen.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_qwenimage.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_sd3.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_skyreels_v2.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_temporal.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_wan.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_wan_vace.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/auraflow_transformer_2d.py +564 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/cogvideox_transformer_3d.py +531 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/dit_transformer_2d.py +226 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/dual_transformer_2d.py +156 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/hunyuan_transformer_2d.py +579 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/latte_transformer_3d.py +331 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/t5_film_transformer.py +436 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_ltx.py +568 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_lumina2.py +548 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__init__.py +18 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_1d.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_1d_blocks.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_blocks.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_blocks_flax.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition_flax.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_3d_blocks.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_3d_condition.cpython-310.pyc +0 -0
pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet.cpython-310.pyc
ADDED
|
Binary file (8.55 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet_union.cpython-310.pyc
ADDED
|
Binary file (8.86 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (2.67 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/stable_audio_transformer.cpython-310.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/t5_film_transformer.cpython-310.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_2d.cpython-310.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_allegro.cpython-310.pyc
ADDED
|
Binary file (8.29 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_bria.cpython-310.pyc
ADDED
|
Binary file (21.3 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_chroma.cpython-310.pyc
ADDED
|
Binary file (19.2 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cogview3plus.cpython-310.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cogview4.cpython-310.pyc
ADDED
|
Binary file (21.5 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cosmos.cpython-310.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_easyanimate.cpython-310.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_flux.cpython-310.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hidream_image.cpython-310.pyc
ADDED
|
Binary file (25.5 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hunyuan_video.cpython-310.pyc
ADDED
|
Binary file (30.1 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hunyuan_video_framepack.cpython-310.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_ltx.cpython-310.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_lumina2.cpython-310.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_mochi.cpython-310.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_omnigen.cpython-310.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_qwenimage.cpython-310.pyc
ADDED
|
Binary file (19.4 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_sd3.cpython-310.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_skyreels_v2.cpython-310.pyc
ADDED
|
Binary file (21.3 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_temporal.cpython-310.pyc
ADDED
|
Binary file (12.9 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_wan.cpython-310.pyc
ADDED
|
Binary file (18.9 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_wan_vace.cpython-310.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/auraflow_transformer_2d.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 AuraFlow Authors, The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from typing import Any, Dict, Optional, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 24 |
+
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 25 |
+
from ...utils.torch_utils import maybe_allow_in_graph
|
| 26 |
+
from ..attention_processor import (
|
| 27 |
+
Attention,
|
| 28 |
+
AttentionProcessor,
|
| 29 |
+
AuraFlowAttnProcessor2_0,
|
| 30 |
+
FusedAuraFlowAttnProcessor2_0,
|
| 31 |
+
)
|
| 32 |
+
from ..embeddings import TimestepEmbedding, Timesteps
|
| 33 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 34 |
+
from ..modeling_utils import ModelMixin
|
| 35 |
+
from ..normalization import AdaLayerNormZero, FP32LayerNorm
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Taken from the original aura flow inference code.
|
| 42 |
+
def find_multiple(n: int, k: int) -> int:
|
| 43 |
+
if n % k == 0:
|
| 44 |
+
return n
|
| 45 |
+
return n + k - (n % k)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Aura Flow patch embed doesn't use convs for projections.
|
| 49 |
+
# Additionally, it uses learned positional embeddings.
|
| 50 |
+
class AuraFlowPatchEmbed(nn.Module):
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
height=224,
|
| 54 |
+
width=224,
|
| 55 |
+
patch_size=16,
|
| 56 |
+
in_channels=3,
|
| 57 |
+
embed_dim=768,
|
| 58 |
+
pos_embed_max_size=None,
|
| 59 |
+
):
|
| 60 |
+
super().__init__()
|
| 61 |
+
|
| 62 |
+
self.num_patches = (height // patch_size) * (width // patch_size)
|
| 63 |
+
self.pos_embed_max_size = pos_embed_max_size
|
| 64 |
+
|
| 65 |
+
self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
|
| 66 |
+
self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)
|
| 67 |
+
|
| 68 |
+
self.patch_size = patch_size
|
| 69 |
+
self.height, self.width = height // patch_size, width // patch_size
|
| 70 |
+
self.base_size = height // patch_size
|
| 71 |
+
|
| 72 |
+
def pe_selection_index_based_on_dim(self, h, w):
|
| 73 |
+
# select subset of positional embedding based on H, W, where H, W is size of latent
|
| 74 |
+
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
|
| 75 |
+
# because original input are in flattened format, we have to flatten this 2d grid as well.
|
| 76 |
+
h_p, w_p = h // self.patch_size, w // self.patch_size
|
| 77 |
+
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
|
| 78 |
+
|
| 79 |
+
# Calculate the top-left corner indices for the centered patch grid
|
| 80 |
+
starth = h_max // 2 - h_p // 2
|
| 81 |
+
startw = w_max // 2 - w_p // 2
|
| 82 |
+
|
| 83 |
+
# Generate the row and column indices for the desired patch grid
|
| 84 |
+
rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
|
| 85 |
+
cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
|
| 86 |
+
|
| 87 |
+
# Create a 2D grid of indices
|
| 88 |
+
row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
|
| 89 |
+
|
| 90 |
+
# Convert the 2D grid indices to flattened 1D indices
|
| 91 |
+
selected_indices = (row_indices * w_max + col_indices).flatten()
|
| 92 |
+
|
| 93 |
+
return selected_indices
|
| 94 |
+
|
| 95 |
+
def forward(self, latent):
|
| 96 |
+
batch_size, num_channels, height, width = latent.size()
|
| 97 |
+
latent = latent.view(
|
| 98 |
+
batch_size,
|
| 99 |
+
num_channels,
|
| 100 |
+
height // self.patch_size,
|
| 101 |
+
self.patch_size,
|
| 102 |
+
width // self.patch_size,
|
| 103 |
+
self.patch_size,
|
| 104 |
+
)
|
| 105 |
+
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
|
| 106 |
+
latent = self.proj(latent)
|
| 107 |
+
pe_index = self.pe_selection_index_based_on_dim(height, width)
|
| 108 |
+
return latent + self.pos_embed[:, pe_index]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Taken from the original Aura flow inference code.
|
| 112 |
+
# Our feedforward only has GELU but Aura uses SiLU.
|
| 113 |
+
class AuraFlowFeedForward(nn.Module):
|
| 114 |
+
def __init__(self, dim, hidden_dim=None) -> None:
|
| 115 |
+
super().__init__()
|
| 116 |
+
if hidden_dim is None:
|
| 117 |
+
hidden_dim = 4 * dim
|
| 118 |
+
|
| 119 |
+
final_hidden_dim = int(2 * hidden_dim / 3)
|
| 120 |
+
final_hidden_dim = find_multiple(final_hidden_dim, 256)
|
| 121 |
+
|
| 122 |
+
self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False)
|
| 123 |
+
self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False)
|
| 124 |
+
self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False)
|
| 125 |
+
|
| 126 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 127 |
+
x = F.silu(self.linear_1(x)) * self.linear_2(x)
|
| 128 |
+
x = self.out_projection(x)
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class AuraFlowPreFinalBlock(nn.Module):
|
| 133 |
+
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
|
| 134 |
+
super().__init__()
|
| 135 |
+
|
| 136 |
+
self.silu = nn.SiLU()
|
| 137 |
+
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False)
|
| 138 |
+
|
| 139 |
+
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
| 140 |
+
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
|
| 141 |
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
| 142 |
+
x = x * (1 + scale)[:, None, :] + shift[:, None, :]
|
| 143 |
+
return x
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@maybe_allow_in_graph
|
| 147 |
+
class AuraFlowSingleTransformerBlock(nn.Module):
|
| 148 |
+
"""Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT."""
|
| 149 |
+
|
| 150 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim):
|
| 151 |
+
super().__init__()
|
| 152 |
+
|
| 153 |
+
self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
|
| 154 |
+
|
| 155 |
+
processor = AuraFlowAttnProcessor2_0()
|
| 156 |
+
self.attn = Attention(
|
| 157 |
+
query_dim=dim,
|
| 158 |
+
cross_attention_dim=None,
|
| 159 |
+
dim_head=attention_head_dim,
|
| 160 |
+
heads=num_attention_heads,
|
| 161 |
+
qk_norm="fp32_layer_norm",
|
| 162 |
+
out_dim=dim,
|
| 163 |
+
bias=False,
|
| 164 |
+
out_bias=False,
|
| 165 |
+
processor=processor,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
|
| 169 |
+
self.ff = AuraFlowFeedForward(dim, dim * 4)
|
| 170 |
+
|
| 171 |
+
def forward(
|
| 172 |
+
self,
|
| 173 |
+
hidden_states: torch.FloatTensor,
|
| 174 |
+
temb: torch.FloatTensor,
|
| 175 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 176 |
+
):
|
| 177 |
+
residual = hidden_states
|
| 178 |
+
attention_kwargs = attention_kwargs or {}
|
| 179 |
+
|
| 180 |
+
# Norm + Projection.
|
| 181 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 182 |
+
|
| 183 |
+
# Attention.
|
| 184 |
+
attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)
|
| 185 |
+
|
| 186 |
+
# Process attention outputs for the `hidden_states`.
|
| 187 |
+
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
|
| 188 |
+
hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 189 |
+
ff_output = self.ff(hidden_states)
|
| 190 |
+
hidden_states = gate_mlp.unsqueeze(1) * ff_output
|
| 191 |
+
hidden_states = residual + hidden_states
|
| 192 |
+
|
| 193 |
+
return hidden_states
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@maybe_allow_in_graph
|
| 197 |
+
class AuraFlowJointTransformerBlock(nn.Module):
|
| 198 |
+
r"""
|
| 199 |
+
Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):
|
| 200 |
+
|
| 201 |
+
* QK Norm in the attention blocks
|
| 202 |
+
* No bias in the attention blocks
|
| 203 |
+
* Most LayerNorms are in FP32
|
| 204 |
+
|
| 205 |
+
Parameters:
|
| 206 |
+
dim (`int`): The number of channels in the input and output.
|
| 207 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 208 |
+
attention_head_dim (`int`): The number of channels in each head.
|
| 209 |
+
is_last (`bool`): Boolean to determine if this is the last block in the model.
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim):
|
| 213 |
+
super().__init__()
|
| 214 |
+
|
| 215 |
+
self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
|
| 216 |
+
self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
|
| 217 |
+
|
| 218 |
+
processor = AuraFlowAttnProcessor2_0()
|
| 219 |
+
self.attn = Attention(
|
| 220 |
+
query_dim=dim,
|
| 221 |
+
cross_attention_dim=None,
|
| 222 |
+
added_kv_proj_dim=dim,
|
| 223 |
+
added_proj_bias=False,
|
| 224 |
+
dim_head=attention_head_dim,
|
| 225 |
+
heads=num_attention_heads,
|
| 226 |
+
qk_norm="fp32_layer_norm",
|
| 227 |
+
out_dim=dim,
|
| 228 |
+
bias=False,
|
| 229 |
+
out_bias=False,
|
| 230 |
+
processor=processor,
|
| 231 |
+
context_pre_only=False,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
|
| 235 |
+
self.ff = AuraFlowFeedForward(dim, dim * 4)
|
| 236 |
+
self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
|
| 237 |
+
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
|
| 238 |
+
|
| 239 |
+
def forward(
|
| 240 |
+
self,
|
| 241 |
+
hidden_states: torch.FloatTensor,
|
| 242 |
+
encoder_hidden_states: torch.FloatTensor,
|
| 243 |
+
temb: torch.FloatTensor,
|
| 244 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 245 |
+
):
|
| 246 |
+
residual = hidden_states
|
| 247 |
+
residual_context = encoder_hidden_states
|
| 248 |
+
attention_kwargs = attention_kwargs or {}
|
| 249 |
+
|
| 250 |
+
# Norm + Projection.
|
| 251 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 252 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 253 |
+
encoder_hidden_states, emb=temb
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Attention.
|
| 257 |
+
attn_output, context_attn_output = self.attn(
|
| 258 |
+
hidden_states=norm_hidden_states,
|
| 259 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 260 |
+
**attention_kwargs,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Process attention outputs for the `hidden_states`.
|
| 264 |
+
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
|
| 265 |
+
hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 266 |
+
hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states)
|
| 267 |
+
hidden_states = residual + hidden_states
|
| 268 |
+
|
| 269 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
| 270 |
+
encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output)
|
| 271 |
+
encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 272 |
+
encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states)
|
| 273 |
+
encoder_hidden_states = residual_context + encoder_hidden_states
|
| 274 |
+
|
| 275 |
+
return encoder_hidden_states, hidden_states
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 279 |
+
r"""
|
| 280 |
+
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
|
| 281 |
+
|
| 282 |
+
Parameters:
|
| 283 |
+
sample_size (`int`): The width of the latent images. This is fixed during training since
|
| 284 |
+
it is used to learn a number of position embeddings.
|
| 285 |
+
patch_size (`int`): Patch size to turn the input data into small patches.
|
| 286 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
|
| 287 |
+
num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
|
| 288 |
+
num_single_dit_layers (`int`, *optional*, defaults to 32):
|
| 289 |
+
The number of layers of Transformer blocks to use. These blocks use concatenated image and text
|
| 290 |
+
representations.
|
| 291 |
+
attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
|
| 292 |
+
num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
|
| 293 |
+
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
| 294 |
+
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
|
| 295 |
+
out_channels (`int`, defaults to 4): Number of output channels.
|
| 296 |
+
pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
|
| 300 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 301 |
+
_supports_gradient_checkpointing = True
|
| 302 |
+
|
| 303 |
+
@register_to_config
|
| 304 |
+
def __init__(
|
| 305 |
+
self,
|
| 306 |
+
sample_size: int = 64,
|
| 307 |
+
patch_size: int = 2,
|
| 308 |
+
in_channels: int = 4,
|
| 309 |
+
num_mmdit_layers: int = 4,
|
| 310 |
+
num_single_dit_layers: int = 32,
|
| 311 |
+
attention_head_dim: int = 256,
|
| 312 |
+
num_attention_heads: int = 12,
|
| 313 |
+
joint_attention_dim: int = 2048,
|
| 314 |
+
caption_projection_dim: int = 3072,
|
| 315 |
+
out_channels: int = 4,
|
| 316 |
+
pos_embed_max_size: int = 1024,
|
| 317 |
+
):
|
| 318 |
+
super().__init__()
|
| 319 |
+
default_out_channels = in_channels
|
| 320 |
+
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
| 321 |
+
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
| 322 |
+
|
| 323 |
+
self.pos_embed = AuraFlowPatchEmbed(
|
| 324 |
+
height=self.config.sample_size,
|
| 325 |
+
width=self.config.sample_size,
|
| 326 |
+
patch_size=self.config.patch_size,
|
| 327 |
+
in_channels=self.config.in_channels,
|
| 328 |
+
embed_dim=self.inner_dim,
|
| 329 |
+
pos_embed_max_size=pos_embed_max_size,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
self.context_embedder = nn.Linear(
|
| 333 |
+
self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False
|
| 334 |
+
)
|
| 335 |
+
self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True)
|
| 336 |
+
self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
|
| 337 |
+
|
| 338 |
+
self.joint_transformer_blocks = nn.ModuleList(
|
| 339 |
+
[
|
| 340 |
+
AuraFlowJointTransformerBlock(
|
| 341 |
+
dim=self.inner_dim,
|
| 342 |
+
num_attention_heads=self.config.num_attention_heads,
|
| 343 |
+
attention_head_dim=self.config.attention_head_dim,
|
| 344 |
+
)
|
| 345 |
+
for i in range(self.config.num_mmdit_layers)
|
| 346 |
+
]
|
| 347 |
+
)
|
| 348 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 349 |
+
[
|
| 350 |
+
AuraFlowSingleTransformerBlock(
|
| 351 |
+
dim=self.inner_dim,
|
| 352 |
+
num_attention_heads=self.config.num_attention_heads,
|
| 353 |
+
attention_head_dim=self.config.attention_head_dim,
|
| 354 |
+
)
|
| 355 |
+
for _ in range(self.config.num_single_dit_layers)
|
| 356 |
+
]
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
|
| 360 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
|
| 361 |
+
|
| 362 |
+
# https://huggingface.co/papers/2309.16588
|
| 363 |
+
# prevents artifacts in the attention maps
|
| 364 |
+
self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)
|
| 365 |
+
|
| 366 |
+
self.gradient_checkpointing = False
|
| 367 |
+
|
| 368 |
+
@property
|
| 369 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 370 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 371 |
+
r"""
|
| 372 |
+
Returns:
|
| 373 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 374 |
+
indexed by its weight name.
|
| 375 |
+
"""
|
| 376 |
+
# set recursively
|
| 377 |
+
processors = {}
|
| 378 |
+
|
| 379 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 380 |
+
if hasattr(module, "get_processor"):
|
| 381 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 382 |
+
|
| 383 |
+
for sub_name, child in module.named_children():
|
| 384 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 385 |
+
|
| 386 |
+
return processors
|
| 387 |
+
|
| 388 |
+
for name, module in self.named_children():
|
| 389 |
+
fn_recursive_add_processors(name, module, processors)
|
| 390 |
+
|
| 391 |
+
return processors
|
| 392 |
+
|
| 393 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 394 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 395 |
+
r"""
|
| 396 |
+
Sets the attention processor to use to compute attention.
|
| 397 |
+
|
| 398 |
+
Parameters:
|
| 399 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 400 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 401 |
+
for **all** `Attention` layers.
|
| 402 |
+
|
| 403 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 404 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 405 |
+
|
| 406 |
+
"""
|
| 407 |
+
count = len(self.attn_processors.keys())
|
| 408 |
+
|
| 409 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 410 |
+
raise ValueError(
|
| 411 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 412 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 416 |
+
if hasattr(module, "set_processor"):
|
| 417 |
+
if not isinstance(processor, dict):
|
| 418 |
+
module.set_processor(processor)
|
| 419 |
+
else:
|
| 420 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 421 |
+
|
| 422 |
+
for sub_name, child in module.named_children():
|
| 423 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 424 |
+
|
| 425 |
+
for name, module in self.named_children():
|
| 426 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 427 |
+
|
| 428 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
|
| 429 |
+
def fuse_qkv_projections(self):
|
| 430 |
+
"""
|
| 431 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 432 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 433 |
+
|
| 434 |
+
<Tip warning={true}>
|
| 435 |
+
|
| 436 |
+
This API is 🧪 experimental.
|
| 437 |
+
|
| 438 |
+
</Tip>
|
| 439 |
+
"""
|
| 440 |
+
self.original_attn_processors = None
|
| 441 |
+
|
| 442 |
+
for _, attn_processor in self.attn_processors.items():
|
| 443 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 444 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 445 |
+
|
| 446 |
+
self.original_attn_processors = self.attn_processors
|
| 447 |
+
|
| 448 |
+
for module in self.modules():
|
| 449 |
+
if isinstance(module, Attention):
|
| 450 |
+
module.fuse_projections(fuse=True)
|
| 451 |
+
|
| 452 |
+
self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
|
| 453 |
+
|
| 454 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 455 |
+
def unfuse_qkv_projections(self):
|
| 456 |
+
"""Disables the fused QKV projection if enabled.
|
| 457 |
+
|
| 458 |
+
<Tip warning={true}>
|
| 459 |
+
|
| 460 |
+
This API is 🧪 experimental.
|
| 461 |
+
|
| 462 |
+
</Tip>
|
| 463 |
+
|
| 464 |
+
"""
|
| 465 |
+
if self.original_attn_processors is not None:
|
| 466 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 467 |
+
|
| 468 |
+
def forward(
|
| 469 |
+
self,
|
| 470 |
+
hidden_states: torch.FloatTensor,
|
| 471 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 472 |
+
timestep: torch.LongTensor = None,
|
| 473 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 474 |
+
return_dict: bool = True,
|
| 475 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
| 476 |
+
if attention_kwargs is not None:
|
| 477 |
+
attention_kwargs = attention_kwargs.copy()
|
| 478 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 479 |
+
else:
|
| 480 |
+
lora_scale = 1.0
|
| 481 |
+
|
| 482 |
+
if USE_PEFT_BACKEND:
|
| 483 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 484 |
+
scale_lora_layers(self, lora_scale)
|
| 485 |
+
else:
|
| 486 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 487 |
+
logger.warning(
|
| 488 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
height, width = hidden_states.shape[-2:]
|
| 492 |
+
|
| 493 |
+
# Apply patch embedding, timestep embedding, and project the caption embeddings.
|
| 494 |
+
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
| 495 |
+
temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
|
| 496 |
+
temb = self.time_step_proj(temb)
|
| 497 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 498 |
+
encoder_hidden_states = torch.cat(
|
| 499 |
+
[self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
# MMDiT blocks.
|
| 503 |
+
for index_block, block in enumerate(self.joint_transformer_blocks):
|
| 504 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 505 |
+
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
| 506 |
+
block,
|
| 507 |
+
hidden_states,
|
| 508 |
+
encoder_hidden_states,
|
| 509 |
+
temb,
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
else:
|
| 513 |
+
encoder_hidden_states, hidden_states = block(
|
| 514 |
+
hidden_states=hidden_states,
|
| 515 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 516 |
+
temb=temb,
|
| 517 |
+
attention_kwargs=attention_kwargs,
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
|
| 521 |
+
if len(self.single_transformer_blocks) > 0:
|
| 522 |
+
encoder_seq_len = encoder_hidden_states.size(1)
|
| 523 |
+
combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 524 |
+
|
| 525 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 526 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 527 |
+
combined_hidden_states = self._gradient_checkpointing_func(
|
| 528 |
+
block,
|
| 529 |
+
combined_hidden_states,
|
| 530 |
+
temb,
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
else:
|
| 534 |
+
combined_hidden_states = block(
|
| 535 |
+
hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
hidden_states = combined_hidden_states[:, encoder_seq_len:]
|
| 539 |
+
|
| 540 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 541 |
+
hidden_states = self.proj_out(hidden_states)
|
| 542 |
+
|
| 543 |
+
# unpatchify
|
| 544 |
+
patch_size = self.config.patch_size
|
| 545 |
+
out_channels = self.config.out_channels
|
| 546 |
+
height = height // patch_size
|
| 547 |
+
width = width // patch_size
|
| 548 |
+
|
| 549 |
+
hidden_states = hidden_states.reshape(
|
| 550 |
+
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels)
|
| 551 |
+
)
|
| 552 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
| 553 |
+
output = hidden_states.reshape(
|
| 554 |
+
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
if USE_PEFT_BACKEND:
|
| 558 |
+
# remove `lora_scale` from each PEFT layer
|
| 559 |
+
unscale_lora_layers(self, lora_scale)
|
| 560 |
+
|
| 561 |
+
if not return_dict:
|
| 562 |
+
return (output,)
|
| 563 |
+
|
| 564 |
+
return Transformer2DModelOutput(sample=output)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/cogvideox_transformer_3d.py
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
|
| 21 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 22 |
+
from ...loaders import PeftAdapterMixin
|
| 23 |
+
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 24 |
+
from ...utils.torch_utils import maybe_allow_in_graph
|
| 25 |
+
from ..attention import Attention, FeedForward
|
| 26 |
+
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
| 27 |
+
from ..cache_utils import CacheMixin
|
| 28 |
+
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
|
| 29 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 30 |
+
from ..modeling_utils import ModelMixin
|
| 31 |
+
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@maybe_allow_in_graph
|
| 38 |
+
class CogVideoXBlock(nn.Module):
|
| 39 |
+
r"""
|
| 40 |
+
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
| 41 |
+
|
| 42 |
+
Parameters:
|
| 43 |
+
dim (`int`):
|
| 44 |
+
The number of channels in the input and output.
|
| 45 |
+
num_attention_heads (`int`):
|
| 46 |
+
The number of heads to use for multi-head attention.
|
| 47 |
+
attention_head_dim (`int`):
|
| 48 |
+
The number of channels in each head.
|
| 49 |
+
time_embed_dim (`int`):
|
| 50 |
+
The number of channels in timestep embedding.
|
| 51 |
+
dropout (`float`, defaults to `0.0`):
|
| 52 |
+
The dropout probability to use.
|
| 53 |
+
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
| 54 |
+
Activation function to be used in feed-forward.
|
| 55 |
+
attention_bias (`bool`, defaults to `False`):
|
| 56 |
+
Whether or not to use bias in attention projection layers.
|
| 57 |
+
qk_norm (`bool`, defaults to `True`):
|
| 58 |
+
Whether or not to use normalization after query and key projections in Attention.
|
| 59 |
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
| 60 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
| 61 |
+
norm_eps (`float`, defaults to `1e-5`):
|
| 62 |
+
Epsilon value for normalization layers.
|
| 63 |
+
final_dropout (`bool` defaults to `False`):
|
| 64 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
| 65 |
+
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
| 66 |
+
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
| 67 |
+
ff_bias (`bool`, defaults to `True`):
|
| 68 |
+
Whether or not to use bias in Feed-forward layer.
|
| 69 |
+
attention_out_bias (`bool`, defaults to `True`):
|
| 70 |
+
Whether or not to use bias in Attention output projection layer.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
dim: int,
|
| 76 |
+
num_attention_heads: int,
|
| 77 |
+
attention_head_dim: int,
|
| 78 |
+
time_embed_dim: int,
|
| 79 |
+
dropout: float = 0.0,
|
| 80 |
+
activation_fn: str = "gelu-approximate",
|
| 81 |
+
attention_bias: bool = False,
|
| 82 |
+
qk_norm: bool = True,
|
| 83 |
+
norm_elementwise_affine: bool = True,
|
| 84 |
+
norm_eps: float = 1e-5,
|
| 85 |
+
final_dropout: bool = True,
|
| 86 |
+
ff_inner_dim: Optional[int] = None,
|
| 87 |
+
ff_bias: bool = True,
|
| 88 |
+
attention_out_bias: bool = True,
|
| 89 |
+
):
|
| 90 |
+
super().__init__()
|
| 91 |
+
|
| 92 |
+
# 1. Self Attention
|
| 93 |
+
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
| 94 |
+
|
| 95 |
+
self.attn1 = Attention(
|
| 96 |
+
query_dim=dim,
|
| 97 |
+
dim_head=attention_head_dim,
|
| 98 |
+
heads=num_attention_heads,
|
| 99 |
+
qk_norm="layer_norm" if qk_norm else None,
|
| 100 |
+
eps=1e-6,
|
| 101 |
+
bias=attention_bias,
|
| 102 |
+
out_bias=attention_out_bias,
|
| 103 |
+
processor=CogVideoXAttnProcessor2_0(),
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# 2. Feed Forward
|
| 107 |
+
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
| 108 |
+
|
| 109 |
+
self.ff = FeedForward(
|
| 110 |
+
dim,
|
| 111 |
+
dropout=dropout,
|
| 112 |
+
activation_fn=activation_fn,
|
| 113 |
+
final_dropout=final_dropout,
|
| 114 |
+
inner_dim=ff_inner_dim,
|
| 115 |
+
bias=ff_bias,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def forward(
|
| 119 |
+
self,
|
| 120 |
+
hidden_states: torch.Tensor,
|
| 121 |
+
encoder_hidden_states: torch.Tensor,
|
| 122 |
+
temb: torch.Tensor,
|
| 123 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 124 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 125 |
+
) -> torch.Tensor:
|
| 126 |
+
text_seq_length = encoder_hidden_states.size(1)
|
| 127 |
+
attention_kwargs = attention_kwargs or {}
|
| 128 |
+
|
| 129 |
+
# norm & modulate
|
| 130 |
+
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
| 131 |
+
hidden_states, encoder_hidden_states, temb
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# attention
|
| 135 |
+
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
| 136 |
+
hidden_states=norm_hidden_states,
|
| 137 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 138 |
+
image_rotary_emb=image_rotary_emb,
|
| 139 |
+
**attention_kwargs,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
| 143 |
+
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
| 144 |
+
|
| 145 |
+
# norm & modulate
|
| 146 |
+
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
| 147 |
+
hidden_states, encoder_hidden_states, temb
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# feed-forward
|
| 151 |
+
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
| 152 |
+
ff_output = self.ff(norm_hidden_states)
|
| 153 |
+
|
| 154 |
+
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
| 155 |
+
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
| 156 |
+
|
| 157 |
+
return hidden_states, encoder_hidden_states
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
| 161 |
+
"""
|
| 162 |
+
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
| 163 |
+
|
| 164 |
+
Parameters:
|
| 165 |
+
num_attention_heads (`int`, defaults to `30`):
|
| 166 |
+
The number of heads to use for multi-head attention.
|
| 167 |
+
attention_head_dim (`int`, defaults to `64`):
|
| 168 |
+
The number of channels in each head.
|
| 169 |
+
in_channels (`int`, defaults to `16`):
|
| 170 |
+
The number of channels in the input.
|
| 171 |
+
out_channels (`int`, *optional*, defaults to `16`):
|
| 172 |
+
The number of channels in the output.
|
| 173 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
| 174 |
+
Whether to flip the sin to cos in the time embedding.
|
| 175 |
+
time_embed_dim (`int`, defaults to `512`):
|
| 176 |
+
Output dimension of timestep embeddings.
|
| 177 |
+
ofs_embed_dim (`int`, defaults to `512`):
|
| 178 |
+
Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
|
| 179 |
+
text_embed_dim (`int`, defaults to `4096`):
|
| 180 |
+
Input dimension of text embeddings from the text encoder.
|
| 181 |
+
num_layers (`int`, defaults to `30`):
|
| 182 |
+
The number of layers of Transformer blocks to use.
|
| 183 |
+
dropout (`float`, defaults to `0.0`):
|
| 184 |
+
The dropout probability to use.
|
| 185 |
+
attention_bias (`bool`, defaults to `True`):
|
| 186 |
+
Whether to use bias in the attention projection layers.
|
| 187 |
+
sample_width (`int`, defaults to `90`):
|
| 188 |
+
The width of the input latents.
|
| 189 |
+
sample_height (`int`, defaults to `60`):
|
| 190 |
+
The height of the input latents.
|
| 191 |
+
sample_frames (`int`, defaults to `49`):
|
| 192 |
+
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
| 193 |
+
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
| 194 |
+
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
| 195 |
+
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
| 196 |
+
patch_size (`int`, defaults to `2`):
|
| 197 |
+
The size of the patches to use in the patch embedding layer.
|
| 198 |
+
temporal_compression_ratio (`int`, defaults to `4`):
|
| 199 |
+
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
| 200 |
+
max_text_seq_length (`int`, defaults to `226`):
|
| 201 |
+
The maximum sequence length of the input text embeddings.
|
| 202 |
+
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
| 203 |
+
Activation function to use in feed-forward.
|
| 204 |
+
timestep_activation_fn (`str`, defaults to `"silu"`):
|
| 205 |
+
Activation function to use when generating the timestep embeddings.
|
| 206 |
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
| 207 |
+
Whether to use elementwise affine in normalization layers.
|
| 208 |
+
norm_eps (`float`, defaults to `1e-5`):
|
| 209 |
+
The epsilon value to use in normalization layers.
|
| 210 |
+
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
| 211 |
+
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
| 212 |
+
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
| 213 |
+
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
|
| 217 |
+
_supports_gradient_checkpointing = True
|
| 218 |
+
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
|
| 219 |
+
|
| 220 |
+
@register_to_config
|
| 221 |
+
def __init__(
|
| 222 |
+
self,
|
| 223 |
+
num_attention_heads: int = 30,
|
| 224 |
+
attention_head_dim: int = 64,
|
| 225 |
+
in_channels: int = 16,
|
| 226 |
+
out_channels: Optional[int] = 16,
|
| 227 |
+
flip_sin_to_cos: bool = True,
|
| 228 |
+
freq_shift: int = 0,
|
| 229 |
+
time_embed_dim: int = 512,
|
| 230 |
+
ofs_embed_dim: Optional[int] = None,
|
| 231 |
+
text_embed_dim: int = 4096,
|
| 232 |
+
num_layers: int = 30,
|
| 233 |
+
dropout: float = 0.0,
|
| 234 |
+
attention_bias: bool = True,
|
| 235 |
+
sample_width: int = 90,
|
| 236 |
+
sample_height: int = 60,
|
| 237 |
+
sample_frames: int = 49,
|
| 238 |
+
patch_size: int = 2,
|
| 239 |
+
patch_size_t: Optional[int] = None,
|
| 240 |
+
temporal_compression_ratio: int = 4,
|
| 241 |
+
max_text_seq_length: int = 226,
|
| 242 |
+
activation_fn: str = "gelu-approximate",
|
| 243 |
+
timestep_activation_fn: str = "silu",
|
| 244 |
+
norm_elementwise_affine: bool = True,
|
| 245 |
+
norm_eps: float = 1e-5,
|
| 246 |
+
spatial_interpolation_scale: float = 1.875,
|
| 247 |
+
temporal_interpolation_scale: float = 1.0,
|
| 248 |
+
use_rotary_positional_embeddings: bool = False,
|
| 249 |
+
use_learned_positional_embeddings: bool = False,
|
| 250 |
+
patch_bias: bool = True,
|
| 251 |
+
):
|
| 252 |
+
super().__init__()
|
| 253 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 254 |
+
|
| 255 |
+
if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
|
| 256 |
+
raise ValueError(
|
| 257 |
+
"There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
|
| 258 |
+
"embeddings. If you're using a custom model and/or believe this should be supported, please open an "
|
| 259 |
+
"issue at https://github.com/huggingface/diffusers/issues."
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# 1. Patch embedding
|
| 263 |
+
self.patch_embed = CogVideoXPatchEmbed(
|
| 264 |
+
patch_size=patch_size,
|
| 265 |
+
patch_size_t=patch_size_t,
|
| 266 |
+
in_channels=in_channels,
|
| 267 |
+
embed_dim=inner_dim,
|
| 268 |
+
text_embed_dim=text_embed_dim,
|
| 269 |
+
bias=patch_bias,
|
| 270 |
+
sample_width=sample_width,
|
| 271 |
+
sample_height=sample_height,
|
| 272 |
+
sample_frames=sample_frames,
|
| 273 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
| 274 |
+
max_text_seq_length=max_text_seq_length,
|
| 275 |
+
spatial_interpolation_scale=spatial_interpolation_scale,
|
| 276 |
+
temporal_interpolation_scale=temporal_interpolation_scale,
|
| 277 |
+
use_positional_embeddings=not use_rotary_positional_embeddings,
|
| 278 |
+
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
| 279 |
+
)
|
| 280 |
+
self.embedding_dropout = nn.Dropout(dropout)
|
| 281 |
+
|
| 282 |
+
# 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)
|
| 283 |
+
|
| 284 |
+
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
| 285 |
+
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
| 286 |
+
|
| 287 |
+
self.ofs_proj = None
|
| 288 |
+
self.ofs_embedding = None
|
| 289 |
+
if ofs_embed_dim:
|
| 290 |
+
self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
|
| 291 |
+
self.ofs_embedding = TimestepEmbedding(
|
| 292 |
+
ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
|
| 293 |
+
) # same as time embeddings, for ofs
|
| 294 |
+
|
| 295 |
+
# 3. Define spatio-temporal transformers blocks
|
| 296 |
+
self.transformer_blocks = nn.ModuleList(
|
| 297 |
+
[
|
| 298 |
+
CogVideoXBlock(
|
| 299 |
+
dim=inner_dim,
|
| 300 |
+
num_attention_heads=num_attention_heads,
|
| 301 |
+
attention_head_dim=attention_head_dim,
|
| 302 |
+
time_embed_dim=time_embed_dim,
|
| 303 |
+
dropout=dropout,
|
| 304 |
+
activation_fn=activation_fn,
|
| 305 |
+
attention_bias=attention_bias,
|
| 306 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
| 307 |
+
norm_eps=norm_eps,
|
| 308 |
+
)
|
| 309 |
+
for _ in range(num_layers)
|
| 310 |
+
]
|
| 311 |
+
)
|
| 312 |
+
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
| 313 |
+
|
| 314 |
+
# 4. Output blocks
|
| 315 |
+
self.norm_out = AdaLayerNorm(
|
| 316 |
+
embedding_dim=time_embed_dim,
|
| 317 |
+
output_dim=2 * inner_dim,
|
| 318 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
| 319 |
+
norm_eps=norm_eps,
|
| 320 |
+
chunk_dim=1,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
if patch_size_t is None:
|
| 324 |
+
# For CogVideox 1.0
|
| 325 |
+
output_dim = patch_size * patch_size * out_channels
|
| 326 |
+
else:
|
| 327 |
+
# For CogVideoX 1.5
|
| 328 |
+
output_dim = patch_size * patch_size * patch_size_t * out_channels
|
| 329 |
+
|
| 330 |
+
self.proj_out = nn.Linear(inner_dim, output_dim)
|
| 331 |
+
|
| 332 |
+
self.gradient_checkpointing = False
|
| 333 |
+
|
| 334 |
+
@property
|
| 335 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 336 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 337 |
+
r"""
|
| 338 |
+
Returns:
|
| 339 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 340 |
+
indexed by its weight name.
|
| 341 |
+
"""
|
| 342 |
+
# set recursively
|
| 343 |
+
processors = {}
|
| 344 |
+
|
| 345 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 346 |
+
if hasattr(module, "get_processor"):
|
| 347 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 348 |
+
|
| 349 |
+
for sub_name, child in module.named_children():
|
| 350 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 351 |
+
|
| 352 |
+
return processors
|
| 353 |
+
|
| 354 |
+
for name, module in self.named_children():
|
| 355 |
+
fn_recursive_add_processors(name, module, processors)
|
| 356 |
+
|
| 357 |
+
return processors
|
| 358 |
+
|
| 359 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 360 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 361 |
+
r"""
|
| 362 |
+
Sets the attention processor to use to compute attention.
|
| 363 |
+
|
| 364 |
+
Parameters:
|
| 365 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 366 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 367 |
+
for **all** `Attention` layers.
|
| 368 |
+
|
| 369 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 370 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 371 |
+
|
| 372 |
+
"""
|
| 373 |
+
count = len(self.attn_processors.keys())
|
| 374 |
+
|
| 375 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 376 |
+
raise ValueError(
|
| 377 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 378 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 382 |
+
if hasattr(module, "set_processor"):
|
| 383 |
+
if not isinstance(processor, dict):
|
| 384 |
+
module.set_processor(processor)
|
| 385 |
+
else:
|
| 386 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 387 |
+
|
| 388 |
+
for sub_name, child in module.named_children():
|
| 389 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 390 |
+
|
| 391 |
+
for name, module in self.named_children():
|
| 392 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 393 |
+
|
| 394 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
|
| 395 |
+
def fuse_qkv_projections(self):
|
| 396 |
+
"""
|
| 397 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 398 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 399 |
+
|
| 400 |
+
<Tip warning={true}>
|
| 401 |
+
|
| 402 |
+
This API is 🧪 experimental.
|
| 403 |
+
|
| 404 |
+
</Tip>
|
| 405 |
+
"""
|
| 406 |
+
self.original_attn_processors = None
|
| 407 |
+
|
| 408 |
+
for _, attn_processor in self.attn_processors.items():
|
| 409 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 410 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 411 |
+
|
| 412 |
+
self.original_attn_processors = self.attn_processors
|
| 413 |
+
|
| 414 |
+
for module in self.modules():
|
| 415 |
+
if isinstance(module, Attention):
|
| 416 |
+
module.fuse_projections(fuse=True)
|
| 417 |
+
|
| 418 |
+
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
|
| 419 |
+
|
| 420 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 421 |
+
def unfuse_qkv_projections(self):
|
| 422 |
+
"""Disables the fused QKV projection if enabled.
|
| 423 |
+
|
| 424 |
+
<Tip warning={true}>
|
| 425 |
+
|
| 426 |
+
This API is 🧪 experimental.
|
| 427 |
+
|
| 428 |
+
</Tip>
|
| 429 |
+
|
| 430 |
+
"""
|
| 431 |
+
if self.original_attn_processors is not None:
|
| 432 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 433 |
+
|
| 434 |
+
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        """
        Denoising forward pass of the transformer.

        Args:
            hidden_states (`torch.Tensor`):
                Video latents of shape `(batch_size, num_frames, channels, height, width)` — the shape is fixed by
                the unpacking below.
            encoder_hidden_states (`torch.Tensor`):
                Text conditioning embeddings; their sequence length (dim 1) determines the text/video split after
                patch embedding.
            timestep (`int`, `float` or `torch.LongTensor`):
                Diffusion timestep(s), projected and embedded into the modulation signal `emb`.
            timestep_cond (`torch.Tensor`, *optional*):
                Extra conditioning passed straight to `self.time_embedding`.
            ofs (`int`, `float` or `torch.LongTensor`, *optional*):
                Additional scalar conditioning; only used when `self.ofs_embedding` exists.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Rotary positional embeddings forwarded to every transformer block.
            attention_kwargs (`Dict[str, Any]`, *optional*):
                Extra kwargs for the attention processors; a `"scale"` entry is popped and used as the LoRA scale.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` or `tuple`: the denoised video latents.
        """
        # Pop the LoRA scale out of `attention_kwargs` on a copy so the caller's dict
        # is never mutated.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Optional extra scalar conditioning ("ofs") is embedded the same way and
        # added onto the timestep embedding.
        if self.ofs_embedding is not None:
            ofs_emb = self.ofs_proj(ofs)
            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
            ofs_emb = self.ofs_embedding(ofs_emb)
            emb = emb + ofs_emb

        # 2. Patch embedding
        # The patch embed concatenates text tokens in front of the patchified video
        # tokens — they are split back apart right below.
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Recompute activations in the backward pass to save memory.
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    attention_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                )

        hidden_states = self.norm_final(hidden_states)

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            # Spatial-only patches: fold the (p, p) patch pixels back into H and W.
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            # Spatio-temporal patches: also fold p_t frames per patch back into the
            # frame dimension (ceil-divided above to cover a partial last group).
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/consisid_transformer_3d.py
ADDED
|
@@ -0,0 +1,789 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
|
| 21 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 22 |
+
from ...loaders import PeftAdapterMixin
|
| 23 |
+
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 24 |
+
from ...utils.torch_utils import maybe_allow_in_graph
|
| 25 |
+
from ..attention import Attention, FeedForward
|
| 26 |
+
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
|
| 27 |
+
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
|
| 28 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 29 |
+
from ..modeling_utils import ModelMixin
|
| 30 |
+
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class PerceiverAttention(nn.Module):
|
| 37 |
+
def __init__(self, dim: int, dim_head: int = 64, heads: int = 8, kv_dim: Optional[int] = None):
|
| 38 |
+
super().__init__()
|
| 39 |
+
|
| 40 |
+
self.scale = dim_head**-0.5
|
| 41 |
+
self.dim_head = dim_head
|
| 42 |
+
self.heads = heads
|
| 43 |
+
inner_dim = dim_head * heads
|
| 44 |
+
|
| 45 |
+
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
|
| 46 |
+
self.norm2 = nn.LayerNorm(dim)
|
| 47 |
+
|
| 48 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
| 49 |
+
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
|
| 50 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
| 51 |
+
|
| 52 |
+
def forward(self, image_embeds: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
# Apply normalization
|
| 54 |
+
image_embeds = self.norm1(image_embeds)
|
| 55 |
+
latents = self.norm2(latents)
|
| 56 |
+
|
| 57 |
+
batch_size, seq_len, _ = latents.shape # Get batch size and sequence length
|
| 58 |
+
|
| 59 |
+
# Compute query, key, and value matrices
|
| 60 |
+
query = self.to_q(latents)
|
| 61 |
+
kv_input = torch.cat((image_embeds, latents), dim=-2)
|
| 62 |
+
key, value = self.to_kv(kv_input).chunk(2, dim=-1)
|
| 63 |
+
|
| 64 |
+
# Reshape the tensors for multi-head attention
|
| 65 |
+
query = query.reshape(query.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
|
| 66 |
+
key = key.reshape(key.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
|
| 67 |
+
value = value.reshape(value.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
|
| 68 |
+
|
| 69 |
+
# attention
|
| 70 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
| 71 |
+
weight = (query * scale) @ (key * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
| 72 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
| 73 |
+
output = weight @ value
|
| 74 |
+
|
| 75 |
+
# Reshape and return the final output
|
| 76 |
+
output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
| 77 |
+
|
| 78 |
+
return self.to_out(output)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class LocalFacialExtractor(nn.Module):
    """
    Perceiver-style resampler that fuses an identity embedding with multi-scale ViT
    features into a fixed set of facial feature tokens.

    Args:
        id_dim (`int`, defaults to `1280`): Width of the input identity embedding.
        vit_dim (`int`, defaults to `1024`): Width of each ViT hidden-state input.
        depth (`int`, defaults to `10`): Total number of attention/feed-forward layer
            pairs; must be divisible by `num_scale`.
        dim_head (`int`, defaults to `64`): Channels per attention head.
        heads (`int`, defaults to `16`): Number of attention heads.
        num_id_token (`int`, defaults to `5`): Number of tokens the identity embedding
            is expanded into.
        num_queries (`int`, defaults to `32`): Number of learnable latent query tokens.
        output_dim (`int`, defaults to `2048`): Width of the returned feature tokens.
        ff_mult (`int`, defaults to `4`): Expansion factor of the feed-forward hidden layer.
        num_scale (`int`, defaults to `5`): Number of ViT feature scales processed.
    """

    def __init__(
        self,
        id_dim: int = 1280,
        vit_dim: int = 1024,
        depth: int = 10,
        dim_head: int = 64,
        heads: int = 16,
        num_id_token: int = 5,
        num_queries: int = 32,
        output_dim: int = 2048,
        ff_mult: int = 4,
        num_scale: int = 5,
    ):
        super().__init__()

        # Storing identity token and query information
        self.num_id_token = num_id_token
        self.vit_dim = vit_dim
        self.num_queries = num_queries
        # The layer stack is partitioned evenly across the feature scales, so `depth`
        # must divide cleanly. Raise instead of `assert` so the check survives
        # `python -O` (asserts are stripped under optimization).
        if depth % num_scale != 0:
            raise ValueError(f"`depth` ({depth}) must be divisible by `num_scale` ({num_scale}).")
        self.depth = depth // num_scale
        self.num_scale = num_scale
        scale = vit_dim**-0.5

        # Learnable latent query embeddings
        self.latents = nn.Parameter(torch.randn(1, num_queries, vit_dim) * scale)
        # Projection layer to map the latent output to the desired dimension
        self.proj_out = nn.Parameter(scale * torch.randn(vit_dim, output_dim))

        # Attention and ConsisIDFeedForward layer stack
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=vit_dim, dim_head=dim_head, heads=heads),  # Perceiver Attention layer
                        nn.Sequential(
                            nn.LayerNorm(vit_dim),
                            nn.Linear(vit_dim, vit_dim * ff_mult, bias=False),
                            nn.GELU(),
                            nn.Linear(vit_dim * ff_mult, vit_dim, bias=False),
                        ),  # ConsisIDFeedForward layer
                    ]
                )
            )

        # One MLP mapping per ViT feature scale (registered as `mapping_{i}` to keep
        # checkpoint state-dict keys stable).
        for i in range(num_scale):
            setattr(
                self,
                f"mapping_{i}",
                nn.Sequential(
                    nn.Linear(vit_dim, vit_dim),
                    nn.LayerNorm(vit_dim),
                    nn.LeakyReLU(),
                    nn.Linear(vit_dim, vit_dim),
                    nn.LayerNorm(vit_dim),
                    nn.LeakyReLU(),
                    nn.Linear(vit_dim, vit_dim),
                ),
            )

        # Mapping for identity embedding vectors
        self.id_embedding_mapping = nn.Sequential(
            nn.Linear(id_dim, vit_dim),
            nn.LayerNorm(vit_dim),
            nn.LeakyReLU(),
            nn.Linear(vit_dim, vit_dim),
            nn.LayerNorm(vit_dim),
            nn.LeakyReLU(),
            nn.Linear(vit_dim, vit_dim * num_id_token),
        )

    def forward(self, id_embeds: torch.Tensor, vit_hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            id_embeds (`torch.Tensor`): Identity embedding whose last dimension is `id_dim`;
                it is expanded into `num_id_token` tokens of width `vit_dim`.
            vit_hidden_states (`List[torch.Tensor]`): At least `num_scale` ViT feature
                tensors, each with last dimension `vit_dim`.

        Returns:
            `torch.Tensor`: Facial feature tokens of shape `(batch, num_queries, output_dim)`.
        """
        # Repeat latent queries for the batch size
        latents = self.latents.repeat(id_embeds.size(0), 1, 1)

        # Map the identity embedding to tokens
        id_embeds = self.id_embedding_mapping(id_embeds)
        id_embeds = id_embeds.reshape(-1, self.num_id_token, self.vit_dim)

        # Concatenate identity tokens with the latent queries
        latents = torch.cat((latents, id_embeds), dim=1)

        # Process each of the num_scale visual feature inputs
        for i in range(self.num_scale):
            vit_feature = getattr(self, f"mapping_{i}")(vit_hidden_states[i])
            ctx_feature = torch.cat((id_embeds, vit_feature), dim=1)

            # Each scale gets its own contiguous slice of the layer stack.
            for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
                latents = attn(ctx_feature, latents) + latents
                latents = ff(latents) + latents

        # Retain only the query latents
        latents = latents[:, : self.num_queries]
        # Project the latents to the output dimension
        latents = latents @ self.proj_out
        return latents
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class PerceiverCrossAttention(nn.Module):
    """Multi-head cross-attention where `hidden_states` attend over `image_embeds`.

    Queries are computed from the hidden states and keys/values from the image
    embeddings; both inputs are layer-normalized before the projections.
    """

    def __init__(self, dim: int = 3072, dim_head: int = 128, heads: int = 16, kv_dim: int = 2048):
        super().__init__()

        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        # Keys/values may live in a different width than the query stream.
        kv_in_dim = dim if kv_dim is None else kv_dim

        # Layer normalization to stabilize training.
        self.norm1 = nn.LayerNorm(kv_in_dim)
        self.norm2 = nn.LayerNorm(dim)

        # Projections producing queries, keys, and values.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_in_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, image_embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `hidden_states` updated by cross-attending over `image_embeds`."""
        image_embeds = self.norm1(image_embeds)
        hidden_states = self.norm2(hidden_states)

        batch, seq_len, _ = hidden_states.shape

        q = self.to_q(hidden_states)
        k, v = self.to_kv(image_embeds).chunk(2, dim=-1)

        # (b, n, h*d) -> (b, h, n, d)
        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.reshape(t.size(0), -1, self.heads, self.dim_head).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Scale q and k each by d**-0.25 before the matmul — more stable in low
        # precision than dividing the attention logits afterwards.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        attn = (q * scale) @ (k * scale).transpose(-2, -1)
        attn = torch.softmax(attn.float(), dim=-1).type(attn.dtype)

        out = attn @ v
        # (b, h, n, d) -> (b, n, h*d)
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.to_out(out)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@maybe_allow_in_graph
class ConsisIDBlock(nn.Module):
    r"""
    Transformer block used in the [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) model.

    The block runs gated self-attention over both the text and video streams, then a
    gated feed-forward over their concatenation, with both sub-layers modulated by the
    timestep embedding through `CogVideoXLayerNormZero`.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        time_embed_dim (`int`): The number of channels in the timestep embedding.
        dropout (`float`, defaults to `0.0`): The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`): Activation function used in feed-forward.
        attention_bias (`bool`, defaults to `False`): Whether to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`): Whether to normalize query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`): Whether normalization layers use learnable affine parameters.
        norm_eps (`float`, defaults to `1e-5`): Epsilon value for normalization layers.
        final_dropout (`bool`, defaults to `True`): Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*): Custom hidden dimension of the feed-forward layer; `4 * dim` if not provided.
        ff_bias (`bool`, defaults to `True`): Whether to use bias in the feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`): Whether to use bias in the attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self-attention, gated/modulated by the timestep embedding.
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed-forward, likewise gated.
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        # Gated self-attention over both streams.
        norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb)
        attn_out, attn_encoder_out = self.attn1(
            hidden_states=norm_hidden,
            encoder_hidden_states=norm_encoder,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + gate_msa * attn_out
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_out

        # Gated feed-forward over the concatenated (text, video) sequence, split
        # back into the two streams afterwards.
        norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb)
        ff_output = self.ff(torch.cat([norm_encoder, norm_hidden], dim=1))
        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
| 352 |
+
"""
|
| 353 |
+
A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).
|
| 354 |
+
|
| 355 |
+
Parameters:
|
| 356 |
+
num_attention_heads (`int`, defaults to `30`):
|
| 357 |
+
The number of heads to use for multi-head attention.
|
| 358 |
+
attention_head_dim (`int`, defaults to `64`):
|
| 359 |
+
The number of channels in each head.
|
| 360 |
+
in_channels (`int`, defaults to `16`):
|
| 361 |
+
The number of channels in the input.
|
| 362 |
+
out_channels (`int`, *optional*, defaults to `16`):
|
| 363 |
+
The number of channels in the output.
|
| 364 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
| 365 |
+
Whether to flip the sin to cos in the time embedding.
|
| 366 |
+
time_embed_dim (`int`, defaults to `512`):
|
| 367 |
+
Output dimension of timestep embeddings.
|
| 368 |
+
text_embed_dim (`int`, defaults to `4096`):
|
| 369 |
+
Input dimension of text embeddings from the text encoder.
|
| 370 |
+
num_layers (`int`, defaults to `30`):
|
| 371 |
+
The number of layers of Transformer blocks to use.
|
| 372 |
+
dropout (`float`, defaults to `0.0`):
|
| 373 |
+
The dropout probability to use.
|
| 374 |
+
attention_bias (`bool`, defaults to `True`):
|
| 375 |
+
Whether to use bias in the attention projection layers.
|
| 376 |
+
sample_width (`int`, defaults to `90`):
|
| 377 |
+
The width of the input latents.
|
| 378 |
+
sample_height (`int`, defaults to `60`):
|
| 379 |
+
The height of the input latents.
|
| 380 |
+
sample_frames (`int`, defaults to `49`):
|
| 381 |
+
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
| 382 |
+
instead of 13 because ConsisID processed 13 latent frames at once in its default and recommended settings,
|
| 383 |
+
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
| 384 |
+
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
| 385 |
+
patch_size (`int`, defaults to `2`):
|
| 386 |
+
The size of the patches to use in the patch embedding layer.
|
| 387 |
+
temporal_compression_ratio (`int`, defaults to `4`):
|
| 388 |
+
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
| 389 |
+
max_text_seq_length (`int`, defaults to `226`):
|
| 390 |
+
The maximum sequence length of the input text embeddings.
|
| 391 |
+
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
| 392 |
+
Activation function to use in feed-forward.
|
| 393 |
+
timestep_activation_fn (`str`, defaults to `"silu"`):
|
| 394 |
+
Activation function to use when generating the timestep embeddings.
|
| 395 |
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
| 396 |
+
Whether to use elementwise affine in normalization layers.
|
| 397 |
+
norm_eps (`float`, defaults to `1e-5`):
|
| 398 |
+
The epsilon value to use in normalization layers.
|
| 399 |
+
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
| 400 |
+
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
| 401 |
+
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
| 402 |
+
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
| 403 |
+
is_train_face (`bool`, defaults to `False`):
|
| 404 |
+
Whether to use enable the identity-preserving module during the training process. When set to `True`, the
|
| 405 |
+
model will focus on identity-preserving tasks.
|
| 406 |
+
is_kps (`bool`, defaults to `False`):
|
| 407 |
+
Whether to enable keypoint for global facial extractor. If `True`, keypoints will be in the model.
|
| 408 |
+
cross_attn_interval (`int`, defaults to `2`):
|
| 409 |
+
The interval between cross-attention layers in the Transformer architecture. A larger value may reduce the
|
| 410 |
+
frequency of cross-attention computations, which can help reduce computational overhead.
|
| 411 |
+
cross_attn_dim_head (`int`, optional, defaults to `128`):
|
| 412 |
+
The dimensionality of each attention head in the cross-attention layers of the Transformer architecture. A
|
| 413 |
+
larger value increases the capacity to attend to more complex patterns, but also increases memory and
|
| 414 |
+
computation costs.
|
| 415 |
+
cross_attn_num_heads (`int`, optional, defaults to `16`):
|
| 416 |
+
The number of attention heads in the cross-attention layers. More heads allow for more parallel attention
|
| 417 |
+
mechanisms, capturing diverse relationships between different components of the input, but can also
|
| 418 |
+
increase computational requirements.
|
| 419 |
+
LFE_id_dim (`int`, optional, defaults to `1280`):
|
| 420 |
+
The dimensionality of the identity vector used in the Local Facial Extractor (LFE). This vector represents
|
| 421 |
+
the identity features of a face, which are important for tasks like face recognition and identity
|
| 422 |
+
preservation across different frames.
|
| 423 |
+
LFE_vit_dim (`int`, optional, defaults to `1024`):
|
| 424 |
+
The dimension of the vision transformer (ViT) output used in the Local Facial Extractor (LFE). This value
|
| 425 |
+
dictates the size of the transformer-generated feature vectors that will be processed for facial feature
|
| 426 |
+
extraction.
|
| 427 |
+
LFE_depth (`int`, optional, defaults to `10`):
|
| 428 |
+
The number of layers in the Local Facial Extractor (LFE). Increasing the depth allows the model to capture
|
| 429 |
+
more complex representations of facial features, but also increases the computational load.
|
| 430 |
+
LFE_dim_head (`int`, optional, defaults to `64`):
|
| 431 |
+
The dimensionality of each attention head in the Local Facial Extractor (LFE). This parameter affects how
|
| 432 |
+
finely the model can process and focus on different parts of the facial features during the extraction
|
| 433 |
+
process.
|
| 434 |
+
LFE_num_heads (`int`, optional, defaults to `16`):
|
| 435 |
+
The number of attention heads in the Local Facial Extractor (LFE). More heads can improve the model's
|
| 436 |
+
ability to capture diverse facial features, but at the cost of increased computational complexity.
|
| 437 |
+
LFE_num_id_token (`int`, optional, defaults to `5`):
|
| 438 |
+
The number of identity tokens used in the Local Facial Extractor (LFE). This defines how many
|
| 439 |
+
identity-related tokens the model will process to ensure face identity preservation during feature
|
| 440 |
+
extraction.
|
| 441 |
+
LFE_num_querie (`int`, optional, defaults to `32`):
|
| 442 |
+
The number of query tokens used in the Local Facial Extractor (LFE). These tokens are used to capture
|
| 443 |
+
high-frequency face-related information that aids in accurate facial feature extraction.
|
| 444 |
+
LFE_output_dim (`int`, optional, defaults to `2048`):
|
| 445 |
+
The output dimension of the Local Facial Extractor (LFE). This dimension determines the size of the feature
|
| 446 |
+
vectors produced by the LFE module, which will be used for subsequent tasks such as face recognition or
|
| 447 |
+
tracking.
|
| 448 |
+
LFE_ff_mult (`int`, optional, defaults to `4`):
|
| 449 |
+
The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial
|
| 450 |
+
Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature
|
| 451 |
+
transformations, but also increases the computation and memory requirements.
|
| 452 |
+
LFE_num_scale (`int`, optional, defaults to `5`):
|
| 453 |
+
The number of different scales visual feature. A higher value increases the model's capacity to learn more
|
| 454 |
+
complex facial feature transformations, but also increases the computation and memory requirements.
|
| 455 |
+
local_face_scale (`float`, defaults to `1.0`):
|
| 456 |
+
A scaling factor used to adjust the importance of local facial features in the model. This can influence
|
| 457 |
+
how strongly the model focuses on high frequency face-related content.
|
| 458 |
+
"""
|
| 459 |
+
|
| 460 |
+
    # Gradient checkpointing can be toggled by `enable_gradient_checkpointing()`.
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        is_train_face: bool = False,
        is_kps: bool = False,
        cross_attn_interval: int = 2,
        cross_attn_dim_head: int = 128,
        cross_attn_num_heads: int = 16,
        LFE_id_dim: int = 1280,
        LFE_vit_dim: int = 1024,
        LFE_depth: int = 10,
        LFE_dim_head: int = 64,
        LFE_num_heads: int = 16,
        LFE_num_id_token: int = 5,
        LFE_num_querie: int = 32,
        LFE_output_dim: int = 2048,
        LFE_ff_mult: int = 4,
        LFE_num_scale: int = 5,
        local_face_scale: float = 1.0,
    ):
        """Build the ConsisID transformer.

        All arguments are captured into ``self.config`` by ``@register_to_config``,
        so the parameter names and order are part of the serialized checkpoint
        contract and must not change.
        """
        super().__init__()
        # Hidden size of every transformer block.
        inner_dim = num_attention_heads * attention_head_dim

        # Learned (non-rotary) positional embeddings without rotary embeddings is
        # not a supported checkpoint configuration — fail early.
        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no ConsisID checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )

        # 1. Patch embedding: joint embedding of text tokens and video patches.
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings: sinusoidal projection followed by an MLP.
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 3. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                ConsisIDBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks: AdaLayerNorm conditioned on the time embedding, then a
        #    linear projection back to patch pixels.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.is_train_face = is_train_face
        self.is_kps = is_kps

        # 5. Define identity-preserving config
        if is_train_face:
            # LFE configs — cached on `self` so `_init_face_inputs` can rebuild the
            # face modules from them.
            self.LFE_id_dim = LFE_id_dim
            self.LFE_vit_dim = LFE_vit_dim
            self.LFE_depth = LFE_depth
            self.LFE_dim_head = LFE_dim_head
            self.LFE_num_heads = LFE_num_heads
            self.LFE_num_id_token = LFE_num_id_token
            self.LFE_num_querie = LFE_num_querie
            self.LFE_output_dim = LFE_output_dim
            self.LFE_ff_mult = LFE_ff_mult
            self.LFE_num_scale = LFE_num_scale
            # cross configs
            self.inner_dim = inner_dim
            self.cross_attn_interval = cross_attn_interval
            # One perceiver cross-attention layer for every `cross_attn_interval`
            # transformer blocks (see `forward`).
            self.num_cross_attn = num_layers // cross_attn_interval
            self.cross_attn_dim_head = cross_attn_dim_head
            self.cross_attn_num_heads = cross_attn_num_heads
            # KV projection dim is fixed at 2/3 of the hidden size.
            self.cross_attn_kv_dim = int(self.inner_dim / 3 * 2)
            self.local_face_scale = local_face_scale
            # face modules
            self._init_face_inputs()

        self.gradient_checkpointing = False
|
| 597 |
+
|
| 598 |
+
def _init_face_inputs(self):
|
| 599 |
+
self.local_facial_extractor = LocalFacialExtractor(
|
| 600 |
+
id_dim=self.LFE_id_dim,
|
| 601 |
+
vit_dim=self.LFE_vit_dim,
|
| 602 |
+
depth=self.LFE_depth,
|
| 603 |
+
dim_head=self.LFE_dim_head,
|
| 604 |
+
heads=self.LFE_num_heads,
|
| 605 |
+
num_id_token=self.LFE_num_id_token,
|
| 606 |
+
num_queries=self.LFE_num_querie,
|
| 607 |
+
output_dim=self.LFE_output_dim,
|
| 608 |
+
ff_mult=self.LFE_ff_mult,
|
| 609 |
+
num_scale=self.LFE_num_scale,
|
| 610 |
+
)
|
| 611 |
+
self.perceiver_cross_attention = nn.ModuleList(
|
| 612 |
+
[
|
| 613 |
+
PerceiverCrossAttention(
|
| 614 |
+
dim=self.inner_dim,
|
| 615 |
+
dim_head=self.cross_attn_dim_head,
|
| 616 |
+
heads=self.cross_attn_num_heads,
|
| 617 |
+
kv_dim=self.cross_attn_kv_dim,
|
| 618 |
+
)
|
| 619 |
+
for _ in range(self.num_cross_attn)
|
| 620 |
+
]
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        # Walk the module tree depth-first; any submodule exposing `get_processor`
        # contributes an entry keyed by its dotted weight path + ".processor".
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors
|
| 647 |
+
|
| 648 |
+
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        # A dict must cover every attention layer exactly once — keys are consumed
        # (popped) as the tree is walked, so a size mismatch means missing/extra keys.
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        # Mirror of `attn_processors`: depth-first walk, assigning either the single
        # shared processor or the per-layer entry popped from the dict.
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
|
| 682 |
+
|
| 683 |
+
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        id_cond: Optional[torch.Tensor] = None,
        id_vit_hidden: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """Denoise a batch of video latents, optionally conditioned on face identity.

        Args:
            hidden_states: Video latents of shape `(batch, num_frames, channels, height, width)`
                (established by the unpacking below).
            encoder_hidden_states: Text encoder hidden states; its sequence length is used
                to split text tokens back out of the joint patch embedding.
            timestep: Current denoising timestep(s), fed to the sinusoidal time projection.
            timestep_cond: Optional extra conditioning for the timestep embedding MLP.
            image_rotary_emb: Optional rotary positional embeddings `(cos, sin)` passed to
                each transformer block.
            attention_kwargs: Optional kwargs; only `"scale"` is consumed here (LoRA scale,
                effective only with the PEFT backend).
            id_cond: Face identity embedding; required when `self.is_train_face` is set
                (no None-check is performed — TODO confirm callers guarantee this).
            id_vit_hidden: List of ViT hidden-state tensors for the face crop; required
                when `self.is_train_face` is set.
            return_dict: When False, return a plain `(output,)` tuple.

        Returns:
            `Transformer2DModelOutput` with the denoised latents, or a 1-tuple.
        """
        # Pop the LoRA scale out of attention_kwargs (copy first so the caller's
        # dict is not mutated).
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # fuse clip and insightface
        valid_face_emb = None
        if self.is_train_face:
            # Move identity inputs onto the latents' device/dtype before fusion.
            id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype)
            id_vit_hidden = [
                tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden
            ]
            valid_face_emb = self.local_facial_extractor(
                id_cond, id_vit_hidden
            )  # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])

        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding — produces a joint [text tokens | video patches] sequence.
        # torch.Size([1, 226, 4096]) torch.Size([1, 13, 32, 60, 90])
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)  # torch.Size([1, 17776, 3072])
        hidden_states = self.embedding_dropout(hidden_states)  # torch.Size([1, 17776, 3072])

        # Split the joint sequence back into text and video parts.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]  # torch.Size([1, 226, 3072])
        hidden_states = hidden_states[:, text_seq_length:]  # torch.Size([1, 17550, 3072])

        # 3. Transformer blocks
        ca_idx = 0
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

            if self.is_train_face:
                # Inject face tokens every `cross_attn_interval` blocks via perceiver
                # cross-attention, scaled by `local_face_scale`.
                if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
                    hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](
                        valid_face_emb, hidden_states
                    )  # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
                    ca_idx += 1

        # Final norm is applied over the re-joined sequence, then the text part is dropped.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        # Note: we use `-1` instead of `channels`:
        # - It is okay to `channels` use for ConsisID (number of input channels is equal to output channels)
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/dit_transformer_2d.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Any, Dict, Optional
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from ...utils import logging
|
| 22 |
+
from ..attention import BasicTransformerBlock
|
| 23 |
+
from ..embeddings import PatchEmbed
|
| 24 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 25 |
+
from ..modeling_utils import ModelMixin
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DiTTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748).

    Parameters:
        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
        in_channels (int, defaults to 4): The number of channels in the input.
        out_channels (int, optional):
            The number of channels in the output. Specify this parameter if the output channel number differs from the
            input.
        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
        norm_num_groups (int, optional, defaults to 32):
            Number of groups for group normalization within Transformer blocks.
        attention_bias (bool, optional, defaults to True):
            Configure if the Transformer blocks' attention should contain a bias parameter.
        sample_size (int, defaults to 32):
            The width of the latent images. This parameter is fixed during training.
        patch_size (int, defaults to 2):
            Size of the patches the model processes, relevant for architectures working on non-sequential data.
        activation_fn (str, optional, defaults to "gelu-approximate"):
            Activation function to use in feed-forward networks within Transformer blocks.
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
            inference.
        upcast_attention (bool, optional, defaults to False):
            If true, upcasts the attention mechanism dimensions for potentially improved performance.
        norm_type (str, optional, defaults to "ada_norm_zero"):
            Specifies the type of normalization used, can be 'ada_norm_zero'.
        norm_elementwise_affine (bool, optional, defaults to False):
            If true, enables element-wise affine parameters in the normalization layers.
        norm_eps (float, optional, defaults to 1e-5):
            A small constant added to the denominator in normalization layers to prevent division by zero.
    """

    # Keep positional embeddings and norms in full precision under layerwise casting.
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _supports_gradient_checkpointing = True
    _supports_group_offloading = False

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        in_channels: int = 4,
        out_channels: Optional[int] = None,
        num_layers: int = 28,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        attention_bias: bool = True,
        sample_size: int = 32,
        patch_size: int = 2,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_zero",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
    ):
        super().__init__()

        # Validate inputs: only the AdaLN-Zero variant of DiT is implemented, and it
        # requires a class-embedding table (`num_embeds_ada_norm`).
        if norm_type != "ada_norm_zero":
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

        # Set some common variables used across the board.
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        self.gradient_checkpointing = False

        # 2. Initialize the position embedding and transformer blocks.
        self.height = self.config.sample_size
        self.width = self.config.sample_size

        self.patch_size = self.config.patch_size
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # 3. Output blocks: final AdaLN-style modulation (shift/scale produced by
        # proj_out_1) followed by a projection back to patch pixels.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(
            self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        return_dict: bool = True,
    ):
        """
        The [`DiTTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # 1. Input: derive the patch-grid dimensions from the actual input, then patchify.
        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
        hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    None,
                    None,
                    None,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output: reuse the first block's AdaLN embedder for the final modulation.
        conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        hidden_states = self.proj_out_2(hidden_states)

        # unpatchify
        # Fix: use the height/width derived from the input in step 1 rather than
        # recomputing `height = width = int(sqrt(num_tokens))`, which silently
        # produced corrupted output for non-square inputs. For square inputs the
        # two are identical, so behavior is unchanged there.
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/dual_transformer_2d.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
from torch import nn
|
| 17 |
+
|
| 18 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 19 |
+
from .transformer_2d import Transformer2DModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DualTransformer2DModel(nn.Module):
|
| 23 |
+
"""
|
| 24 |
+
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
| 25 |
+
|
| 26 |
+
Parameters:
|
| 27 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
| 28 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
| 29 |
+
in_channels (`int`, *optional*):
|
| 30 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
| 31 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
| 32 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
| 33 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
| 34 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
| 35 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
| 36 |
+
`ImagePositionalEmbeddings`.
|
| 37 |
+
num_vector_embeds (`int`, *optional*):
|
| 38 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
| 39 |
+
Includes the class for the masked latent pixel.
|
| 40 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 41 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
| 42 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
| 43 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
| 44 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
| 45 |
+
attention_bias (`bool`, *optional*):
|
| 46 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
    self,
    num_attention_heads: int = 16,
    attention_head_dim: int = 88,
    in_channels: Optional[int] = None,
    num_layers: int = 1,
    dropout: float = 0.0,
    norm_num_groups: int = 32,
    cross_attention_dim: Optional[int] = None,
    attention_bias: bool = False,
    sample_size: Optional[int] = None,
    num_vector_embeds: Optional[int] = None,
    activation_fn: str = "geglu",
    num_embeds_ada_norm: Optional[int] = None,
):
    """Build the two identically-configured `Transformer2DModel` submodules.

    Every argument is forwarded verbatim to both submodules; see `Transformer2DModel`
    for the meaning of each parameter.
    """
    super().__init__()

    # Both sub-transformers share the exact same configuration.
    shared_kwargs = dict(
        num_attention_heads=num_attention_heads,
        attention_head_dim=attention_head_dim,
        in_channels=in_channels,
        num_layers=num_layers,
        dropout=dropout,
        norm_num_groups=norm_num_groups,
        cross_attention_dim=cross_attention_dim,
        attention_bias=attention_bias,
        sample_size=sample_size,
        num_vector_embeds=num_vector_embeds,
        activation_fn=activation_fn,
        num_embeds_ada_norm=num_embeds_ada_norm,
    )
    self.transformers = nn.ModuleList(Transformer2DModel(**shared_kwargs) for _ in range(2))

    # Variables that can be overridden by a pipeline:

    # Blend weight of transformers[...] outputs during inference:
    # result = mix_ratio * delta_0 + (1 - mix_ratio) * delta_1.
    self.mix_ratio = 0.5

    # `encoder_hidden_states` is expected to have shape
    # `(batch_size, condition_lengths[0] + condition_lengths[1], num_features)`.
    self.condition_lengths = [77, 257]

    # Which transformer encodes which condition slice.
    # E.g. `[1, 0]` means `transformers[1]` handles condition 0 and `transformers[0]` handles condition 1.
    self.transformer_index_for_condition = [1, 0]
| 98 |
+
def forward(
    self,
    hidden_states,
    encoder_hidden_states,
    timestep=None,
    attention_mask=None,
    cross_attention_kwargs=None,
    return_dict: bool = True,
):
    """
    Args:
        hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
            When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states.
        encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
            Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
            self-attention.
        timestep ( `torch.long`, *optional*):
            Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
        attention_mask (`torch.Tensor`, *optional*):
            Optional attention mask to be applied in Attention. Currently accepted but unused.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.

    Returns:
        [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`:
        [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is the sample tensor.
    """
    residual = hidden_states

    # attention_mask is not used yet.
    deltas = []
    offset = 0
    for cond_idx in range(2):
        # Slice out this condition's tokens and route them to the transformer
        # configured for it.
        length = self.condition_lengths[cond_idx]
        condition_tokens = encoder_hidden_states[:, offset : offset + length]
        sub_transformer = self.transformers[self.transformer_index_for_condition[cond_idx]]
        encoded = sub_transformer(
            residual,
            encoder_hidden_states=condition_tokens,
            timestep=timestep,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]
        # Keep only the residual contribution of each sub-transformer.
        deltas.append(encoded - residual)
        offset += length

    # Blend the two contributions and re-add the input.
    output_states = deltas[0] * self.mix_ratio + deltas[1] * (1 - self.mix_ratio)
    output_states = output_states + residual

    if not return_dict:
        return (output_states,)

    return Transformer2DModelOutput(sample=output_states)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/hunyuan_transformer_2d.py
ADDED
|
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Dict, Optional, Union
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn
|
| 18 |
+
|
| 19 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 20 |
+
from ...utils import logging
|
| 21 |
+
from ...utils.torch_utils import maybe_allow_in_graph
|
| 22 |
+
from ..attention import FeedForward
|
| 23 |
+
from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
|
| 24 |
+
from ..embeddings import (
|
| 25 |
+
HunyuanCombinedTimestepTextSizeStyleEmbedding,
|
| 26 |
+
PatchEmbed,
|
| 27 |
+
PixArtAlphaTextProjection,
|
| 28 |
+
)
|
| 29 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 30 |
+
from ..modeling_utils import ModelMixin
|
| 31 |
+
from ..normalization import AdaLayerNormContinuous, FP32LayerNorm
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AdaLayerNormShift(nn.Module):
    r"""
    LayerNorm followed by a timestep-conditioned additive shift (no scale/gate), as used in Hunyuan-DiT.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether the inner LayerNorm carries learnable affine parameters.
        eps (`float`, *optional*, defaults to 1e-6):
            Epsilon added to the LayerNorm denominator.
    """

    def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim)
        self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # Evaluate SiLU in fp32 for numerical stability, then cast back to emb's dtype.
        gate = self.silu(emb.to(torch.float32)).to(emb.dtype)
        shift = self.linear(gate)
        # Broadcast the per-sample shift across the sequence dimension.
        return self.norm(x) + shift.unsqueeze(dim=1)
+
|
| 58 |
+
@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
    r"""
    Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
    QKNorm

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        cross_attention_dim (`int`,*optional*):
            The size of the encoder_hidden_states vector for cross attention.
        dropout(`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        activation_fn (`str`,*optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*):
            The size of the hidden layer in the feed-forward block. Defaults to `None`.
        ff_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the feed-forward block.
        skip (`bool`, *optional*, defaults to `False`):
            Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
        qk_norm (`bool`, *optional*, defaults to `True`):
            Whether to use normalization in QK calculation. Defaults to `True`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        cross_attention_dim: int = 1024,
        dropout=0.0,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-6,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        skip: bool = False,
        qk_norm: bool = True,
    ):
        super().__init__()

        # Define 3 blocks. Each block has its own normalization layer.
        # NOTE: when new version comes, check norm2 and norm 3
        # 1. Self-Attn
        self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )

        # 2. Cross-Attn
        self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )
        # 3. Feed-forward
        self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.ff = FeedForward(
            dim,
            dropout=dropout,  ### 0.0
            activation_fn=activation_fn,  ### approx GeLU
            final_dropout=final_dropout,  ### 0.0
            inner_dim=ff_inner_dim,  ### int(dim * mlp_ratio)
            bias=ff_bias,
        )

        # 4. Skip Connection
        # Only up-blocks receive a long skip input, so the projection is created on demand.
        if skip:
            self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb=None,
        skip=None,
    ) -> torch.Tensor:
        """Run one DiT block: optional long-skip merge, self-attn, cross-attn, then FFN.

        `skip` must be provided when the block was constructed with `skip=True`
        (it is concatenated with `hidden_states` along the channel dim).
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([hidden_states, skip], dim=-1)
            cat = self.skip_norm(cat)
            hidden_states = self.skip_linear(cat)

        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states, temb)  ### checked: self.norm1 is correct
        attn_output = self.attn1(
            norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_output

        # 2. Cross-Attention
        hidden_states = hidden_states + self.attn2(
            self.norm2(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
        mlp_inputs = self.norm3(hidden_states)
        hidden_states = hidden_states + self.ff(mlp_inputs)

        return hidden_states
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
| 204 |
+
"""
|
| 205 |
+
HunYuanDiT: Diffusion model with a Transformer backbone.
|
| 206 |
+
|
| 207 |
+
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
| 208 |
+
|
| 209 |
+
Parameters:
|
| 210 |
+
num_attention_heads (`int`, *optional*, defaults to 16):
|
| 211 |
+
The number of heads to use for multi-head attention.
|
| 212 |
+
attention_head_dim (`int`, *optional*, defaults to 88):
|
| 213 |
+
The number of channels in each head.
|
| 214 |
+
in_channels (`int`, *optional*):
|
| 215 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
| 216 |
+
patch_size (`int`, *optional*):
|
| 217 |
+
The size of the patch to use for the input.
|
| 218 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
| 219 |
+
Activation function to use in feed-forward.
|
| 220 |
+
sample_size (`int`, *optional*):
|
| 221 |
+
The width of the latent images. This is fixed during training since it is used to learn a number of
|
| 222 |
+
position embeddings.
|
| 223 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 224 |
+
The dropout probability to use.
|
| 225 |
+
cross_attention_dim (`int`, *optional*):
|
| 226 |
+
The number of dimension in the clip text embedding.
|
| 227 |
+
hidden_size (`int`, *optional*):
|
| 228 |
+
The size of hidden layer in the conditioning embedding layers.
|
| 229 |
+
num_layers (`int`, *optional*, defaults to 1):
|
| 230 |
+
The number of layers of Transformer blocks to use.
|
| 231 |
+
mlp_ratio (`float`, *optional*, defaults to 4.0):
|
| 232 |
+
The ratio of the hidden layer size to the input size.
|
| 233 |
+
learn_sigma (`bool`, *optional*, defaults to `True`):
|
| 234 |
+
Whether to predict variance.
|
| 235 |
+
cross_attention_dim_t5 (`int`, *optional*):
|
| 236 |
+
The number dimensions in t5 text embedding.
|
| 237 |
+
pooled_projection_dim (`int`, *optional*):
|
| 238 |
+
The size of the pooled projection.
|
| 239 |
+
text_len (`int`, *optional*):
|
| 240 |
+
The length of the clip text embedding.
|
| 241 |
+
text_len_t5 (`int`, *optional*):
|
| 242 |
+
The length of the T5 text embedding.
|
| 243 |
+
use_style_cond_and_image_meta_size (`bool`, *optional*):
|
| 244 |
+
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
|
| 248 |
+
_supports_group_offloading = False
|
| 249 |
+
|
| 250 |
+
@register_to_config
|
| 251 |
+
def __init__(
|
| 252 |
+
self,
|
| 253 |
+
num_attention_heads: int = 16,
|
| 254 |
+
attention_head_dim: int = 88,
|
| 255 |
+
in_channels: Optional[int] = None,
|
| 256 |
+
patch_size: Optional[int] = None,
|
| 257 |
+
activation_fn: str = "gelu-approximate",
|
| 258 |
+
sample_size=32,
|
| 259 |
+
hidden_size=1152,
|
| 260 |
+
num_layers: int = 28,
|
| 261 |
+
mlp_ratio: float = 4.0,
|
| 262 |
+
learn_sigma: bool = True,
|
| 263 |
+
cross_attention_dim: int = 1024,
|
| 264 |
+
norm_type: str = "layer_norm",
|
| 265 |
+
cross_attention_dim_t5: int = 2048,
|
| 266 |
+
pooled_projection_dim: int = 1024,
|
| 267 |
+
text_len: int = 77,
|
| 268 |
+
text_len_t5: int = 256,
|
| 269 |
+
use_style_cond_and_image_meta_size: bool = True,
|
| 270 |
+
):
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
| 273 |
+
self.num_heads = num_attention_heads
|
| 274 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 275 |
+
|
| 276 |
+
self.text_embedder = PixArtAlphaTextProjection(
|
| 277 |
+
in_features=cross_attention_dim_t5,
|
| 278 |
+
hidden_size=cross_attention_dim_t5 * 4,
|
| 279 |
+
out_features=cross_attention_dim,
|
| 280 |
+
act_fn="silu_fp32",
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))
|
| 284 |
+
|
| 285 |
+
self.pos_embed = PatchEmbed(
|
| 286 |
+
height=sample_size,
|
| 287 |
+
width=sample_size,
|
| 288 |
+
in_channels=in_channels,
|
| 289 |
+
embed_dim=hidden_size,
|
| 290 |
+
patch_size=patch_size,
|
| 291 |
+
pos_embed_type=None,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
|
| 295 |
+
hidden_size,
|
| 296 |
+
pooled_projection_dim=pooled_projection_dim,
|
| 297 |
+
seq_len=text_len_t5,
|
| 298 |
+
cross_attention_dim=cross_attention_dim_t5,
|
| 299 |
+
use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# HunyuanDiT Blocks
|
| 303 |
+
self.blocks = nn.ModuleList(
|
| 304 |
+
[
|
| 305 |
+
HunyuanDiTBlock(
|
| 306 |
+
dim=self.inner_dim,
|
| 307 |
+
num_attention_heads=self.config.num_attention_heads,
|
| 308 |
+
activation_fn=activation_fn,
|
| 309 |
+
ff_inner_dim=int(self.inner_dim * mlp_ratio),
|
| 310 |
+
cross_attention_dim=cross_attention_dim,
|
| 311 |
+
qk_norm=True, # See https://huggingface.co/papers/2302.05442 for details.
|
| 312 |
+
skip=layer > num_layers // 2,
|
| 313 |
+
)
|
| 314 |
+
for layer in range(num_layers)
|
| 315 |
+
]
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 319 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
| 320 |
+
|
| 321 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
|
| 322 |
+
def fuse_qkv_projections(self):
|
| 323 |
+
"""
|
| 324 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 325 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 326 |
+
|
| 327 |
+
<Tip warning={true}>
|
| 328 |
+
|
| 329 |
+
This API is 🧪 experimental.
|
| 330 |
+
|
| 331 |
+
</Tip>
|
| 332 |
+
"""
|
| 333 |
+
self.original_attn_processors = None
|
| 334 |
+
|
| 335 |
+
for _, attn_processor in self.attn_processors.items():
|
| 336 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 337 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 338 |
+
|
| 339 |
+
self.original_attn_processors = self.attn_processors
|
| 340 |
+
|
| 341 |
+
for module in self.modules():
|
| 342 |
+
if isinstance(module, Attention):
|
| 343 |
+
module.fuse_projections(fuse=True)
|
| 344 |
+
|
| 345 |
+
self.set_attn_processor(FusedHunyuanAttnProcessor2_0())
|
| 346 |
+
|
| 347 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 348 |
+
def unfuse_qkv_projections(self):
|
| 349 |
+
"""Disables the fused QKV projection if enabled.
|
| 350 |
+
|
| 351 |
+
<Tip warning={true}>
|
| 352 |
+
|
| 353 |
+
This API is 🧪 experimental.
|
| 354 |
+
|
| 355 |
+
</Tip>
|
| 356 |
+
|
| 357 |
+
"""
|
| 358 |
+
if self.original_attn_processors is not None:
|
| 359 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 360 |
+
|
| 361 |
+
@property
|
| 362 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 363 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 364 |
+
r"""
|
| 365 |
+
Returns:
|
| 366 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 367 |
+
indexed by its weight name.
|
| 368 |
+
"""
|
| 369 |
+
# set recursively
|
| 370 |
+
processors = {}
|
| 371 |
+
|
| 372 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 373 |
+
if hasattr(module, "get_processor"):
|
| 374 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 375 |
+
|
| 376 |
+
for sub_name, child in module.named_children():
|
| 377 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 378 |
+
|
| 379 |
+
return processors
|
| 380 |
+
|
| 381 |
+
for name, module in self.named_children():
|
| 382 |
+
fn_recursive_add_processors(name, module, processors)
|
| 383 |
+
|
| 384 |
+
return processors
|
| 385 |
+
|
| 386 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 387 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 388 |
+
r"""
|
| 389 |
+
Sets the attention processor to use to compute attention.
|
| 390 |
+
|
| 391 |
+
Parameters:
|
| 392 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 393 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 394 |
+
for **all** `Attention` layers.
|
| 395 |
+
|
| 396 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 397 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 398 |
+
|
| 399 |
+
"""
|
| 400 |
+
count = len(self.attn_processors.keys())
|
| 401 |
+
|
| 402 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 403 |
+
raise ValueError(
|
| 404 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 405 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 409 |
+
if hasattr(module, "set_processor"):
|
| 410 |
+
if not isinstance(processor, dict):
|
| 411 |
+
module.set_processor(processor)
|
| 412 |
+
else:
|
| 413 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 414 |
+
|
| 415 |
+
for sub_name, child in module.named_children():
|
| 416 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 417 |
+
|
| 418 |
+
for name, module in self.named_children():
|
| 419 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 420 |
+
|
| 421 |
+
def set_default_attn_processor(self):
|
| 422 |
+
"""
|
| 423 |
+
Disables custom attention processors and sets the default attention implementation.
|
| 424 |
+
"""
|
| 425 |
+
self.set_attn_processor(HunyuanAttnProcessor2_0())
|
| 426 |
+
|
| 427 |
+
def forward(
|
| 428 |
+
self,
|
| 429 |
+
hidden_states,
|
| 430 |
+
timestep,
|
| 431 |
+
encoder_hidden_states=None,
|
| 432 |
+
text_embedding_mask=None,
|
| 433 |
+
encoder_hidden_states_t5=None,
|
| 434 |
+
text_embedding_mask_t5=None,
|
| 435 |
+
image_meta_size=None,
|
| 436 |
+
style=None,
|
| 437 |
+
image_rotary_emb=None,
|
| 438 |
+
controlnet_block_samples=None,
|
| 439 |
+
return_dict=True,
|
| 440 |
+
):
|
| 441 |
+
"""
|
| 442 |
+
The [`HunyuanDiT2DModel`] forward method.
|
| 443 |
+
|
| 444 |
+
Args:
|
| 445 |
+
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
|
| 446 |
+
The input tensor.
|
| 447 |
+
timestep ( `torch.LongTensor`, *optional*):
|
| 448 |
+
Used to indicate denoising step.
|
| 449 |
+
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
| 450 |
+
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
|
| 451 |
+
text_embedding_mask: torch.Tensor
|
| 452 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
|
| 453 |
+
of `BertModel`.
|
| 454 |
+
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
| 455 |
+
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
|
| 456 |
+
text_embedding_mask_t5: torch.Tensor
|
| 457 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
|
| 458 |
+
of T5 Text Encoder.
|
| 459 |
+
image_meta_size (torch.Tensor):
|
| 460 |
+
Conditional embedding indicate the image sizes
|
| 461 |
+
style: torch.Tensor:
|
| 462 |
+
Conditional embedding indicate the style
|
| 463 |
+
image_rotary_emb (`torch.Tensor`):
|
| 464 |
+
The image rotary embeddings to apply on query and key tensors during attention calculation.
|
| 465 |
+
return_dict: bool
|
| 466 |
+
Whether to return a dictionary.
|
| 467 |
+
"""
|
| 468 |
+
|
| 469 |
+
height, width = hidden_states.shape[-2:]
|
| 470 |
+
|
| 471 |
+
hidden_states = self.pos_embed(hidden_states)
|
| 472 |
+
|
| 473 |
+
temb = self.time_extra_emb(
|
| 474 |
+
timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
|
| 475 |
+
) # [B, D]
|
| 476 |
+
|
| 477 |
+
# text projection
|
| 478 |
+
batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
|
| 479 |
+
encoder_hidden_states_t5 = self.text_embedder(
|
| 480 |
+
encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
|
| 481 |
+
)
|
| 482 |
+
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
|
| 483 |
+
|
| 484 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
|
| 485 |
+
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
|
| 486 |
+
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
|
| 487 |
+
|
| 488 |
+
encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
|
| 489 |
+
|
| 490 |
+
skips = []
|
| 491 |
+
for layer, block in enumerate(self.blocks):
|
| 492 |
+
if layer > self.config.num_layers // 2:
|
| 493 |
+
if controlnet_block_samples is not None:
|
| 494 |
+
skip = skips.pop() + controlnet_block_samples.pop()
|
| 495 |
+
else:
|
| 496 |
+
skip = skips.pop()
|
| 497 |
+
hidden_states = block(
|
| 498 |
+
hidden_states,
|
| 499 |
+
temb=temb,
|
| 500 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 501 |
+
image_rotary_emb=image_rotary_emb,
|
| 502 |
+
skip=skip,
|
| 503 |
+
) # (N, L, D)
|
| 504 |
+
else:
|
| 505 |
+
hidden_states = block(
|
| 506 |
+
hidden_states,
|
| 507 |
+
temb=temb,
|
| 508 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 509 |
+
image_rotary_emb=image_rotary_emb,
|
| 510 |
+
) # (N, L, D)
|
| 511 |
+
|
| 512 |
+
if layer < (self.config.num_layers // 2 - 1):
|
| 513 |
+
skips.append(hidden_states)
|
| 514 |
+
|
| 515 |
+
if controlnet_block_samples is not None and len(controlnet_block_samples) != 0:
|
| 516 |
+
raise ValueError("The number of controls is not equal to the number of skip connections.")
|
| 517 |
+
|
| 518 |
+
# final layer
|
| 519 |
+
hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
|
| 520 |
+
hidden_states = self.proj_out(hidden_states)
|
| 521 |
+
# (N, L, patch_size ** 2 * out_channels)
|
| 522 |
+
|
| 523 |
+
# unpatchify: (N, out_channels, H, W)
|
| 524 |
+
patch_size = self.pos_embed.patch_size
|
| 525 |
+
height = height // patch_size
|
| 526 |
+
width = width // patch_size
|
| 527 |
+
|
| 528 |
+
hidden_states = hidden_states.reshape(
|
| 529 |
+
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
|
| 530 |
+
)
|
| 531 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
| 532 |
+
output = hidden_states.reshape(
|
| 533 |
+
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
|
| 534 |
+
)
|
| 535 |
+
if not return_dict:
|
| 536 |
+
return (output,)
|
| 537 |
+
return Transformer2DModelOutput(sample=output)
|
| 538 |
+
|
| 539 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
    """
    Sets the attention processor to use [feed forward
    chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

    Parameters:
        chunk_size (`int`, *optional*):
            The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
            over each tensor of dim=`dim`.
        dim (`int`, *optional*, defaults to `0`):
            The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
            or dim=1 (sequence length).
    """
    if dim not in [0, 1]:
        raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

    # A falsy chunk size (None or 0) falls back to the default of 1.
    chunk_size = chunk_size or 1

    def _propagate_chunking(module: torch.nn.Module, chunk_size: int, dim: int) -> None:
        # Only modules that expose the hook get configured; everything else is just traversed.
        if hasattr(module, "set_chunk_feed_forward"):
            module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

        for submodule in module.children():
            _propagate_chunking(submodule, chunk_size, dim)

    for submodule in self.children():
        _propagate_chunking(submodule, chunk_size, dim)
|
| 568 |
+
|
| 569 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
    """Disable feed-forward chunking by resetting every chunk-capable submodule to the unchunked path."""

    def _clear_chunking(module: torch.nn.Module, chunk_size: int, dim: int) -> None:
        # Passing chunk_size=None restores the default (unchunked) feed-forward behavior.
        if hasattr(module, "set_chunk_feed_forward"):
            module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

        for submodule in module.children():
            _clear_chunking(submodule, chunk_size, dim)

    for submodule in self.children():
        _clear_chunking(submodule, None, 0)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/latte_transformer_3d.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 the Latte Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from ..attention import BasicTransformerBlock
|
| 22 |
+
from ..cache_utils import CacheMixin
|
| 23 |
+
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
| 24 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 25 |
+
from ..modeling_utils import ModelMixin
|
| 26 |
+
from ..normalization import AdaLayerNormSingle
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
    # Enables PyTorch activation checkpointing support in the forward pass below.
    _supports_gradient_checkpointing = True

    # NOTE(review): this triple-quoted string follows an assignment, so it is an expression
    # statement rather than the class `__doc__`; kept in place to match upstream.
    """
    A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
    https://github.com/Vchitect/Latte

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input.
        out_channels (`int`, *optional*):
            The number of channels in the output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*):
            The size of the patches to use in the patch embedding layer.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states. During inference, you can denoise for up to but not more steps than
            `num_embeds_ada_norm`.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
        caption_channels (`int`, *optional*):
            The number of channels in the caption embeddings.
        video_length (`int`, *optional*):
            The number of frames in the video-like data.
    """

    # These module-name patterns are kept in full precision under layerwise casting.
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: int = 64,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        caption_channels: Optional[int] = None,
        video_length: int = 16,
    ):
        """Build the patch embedding, spatial/temporal transformer stacks, and output projection."""
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Define input layers
        self.height = sample_size
        self.width = sample_size

        # Positional embeddings are learned at a 64-pixel base resolution and interpolated up.
        interpolation_scale = self.config.sample_size // 64
        interpolation_scale = max(interpolation_scale, 1)
        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
        )

        # 2. Define spatial transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for d in range(num_layers)
            ]
        )

        # 3. Define temporal transformers blocks
        # Temporal blocks are self-attention only (cross_attention_dim=None): text
        # conditioning is applied in the spatial blocks.
        self.temporal_transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=None,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        # Learned (shift, scale) table combined with the embedded timestep for final modulation.
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

        # 5. Latte other blocks.
        self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        # define temporal positional embedding
        temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
            inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
        )  # 1152 hidden size
        # Non-persistent: the sinusoidal table is deterministic, so it is not stored in checkpoints.
        self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        enable_temporal_attentions: bool = True,
        return_dict: bool = True,
    ):
        """
        The [`LatteTransformer3DModel`] forward method.

        Args:
            hidden_states shape `(batch size, channel, num_frame, height, width)`:
                Input `hidden_states`.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            enable_temporal_attentions:
                (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """

        # Reshape hidden states
        batch_size, channels, num_frame, height, width = hidden_states.shape
        # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)

        # Input
        # From here on `height`/`width` are measured in patches, not pixels.
        height, width = (
            hidden_states.shape[-2] // self.config.patch_size,
            hidden_states.shape[-1] // self.config.patch_size,
        )
        num_patches = height * width

        hidden_states = self.pos_embed(hidden_states)  # already add positional embeddings

        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        # Prepare text embeddings for spatial block
        # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
            num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
        ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])

        # Prepare timesteps for spatial and temporal block
        timestep_spatial = timestep.repeat_interleave(
            num_frame, dim=0, output_size=timestep.shape[0] * num_frame
        ).view(-1, timestep.shape[-1])
        timestep_temp = timestep.repeat_interleave(
            num_patches, dim=0, output_size=timestep.shape[0] * num_patches
        ).view(-1, timestep.shape[-1])

        # Spatial and temporal transformer blocks
        # Each layer alternates: spatial attention over patches, then (optionally)
        # temporal attention over frames, with reshapes swapping the two layouts.
        for i, (spatial_block, temp_block) in enumerate(
            zip(self.transformer_blocks, self.temporal_transformer_blocks)
        ):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    spatial_block,
                    hidden_states,
                    None,  # attention_mask
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    None,  # cross_attention_kwargs
                    None,  # class_labels
                )
            else:
                hidden_states = spatial_block(
                    hidden_states,
                    None,  # attention_mask
                    encoder_hidden_states_spatial,
                    encoder_attention_mask,
                    timestep_spatial,
                    None,  # cross_attention_kwargs
                    None,  # class_labels
                )

            if enable_temporal_attentions:
                # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
                hidden_states = hidden_states.reshape(
                    batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
                ).permute(0, 2, 1, 3)
                hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

                # Temporal position embedding is added once, before the first temporal block.
                if i == 0 and num_frame > 1:
                    hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)

                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states = self._gradient_checkpointing_func(
                        temp_block,
                        hidden_states,
                        None,  # attention_mask
                        None,  # encoder_hidden_states
                        None,  # encoder_attention_mask
                        timestep_temp,
                        None,  # cross_attention_kwargs
                        None,  # class_labels
                    )
                else:
                    hidden_states = temp_block(
                        hidden_states,
                        None,  # attention_mask
                        None,  # encoder_hidden_states
                        None,  # encoder_attention_mask
                        timestep_temp,
                        None,  # cross_attention_kwargs
                        None,  # class_labels
                    )

                # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size
                hidden_states = hidden_states.reshape(
                    batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
                ).permute(0, 2, 1, 3)
                hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

        embedded_timestep = embedded_timestep.repeat_interleave(
            num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
        ).view(-1, embedded_timestep.shape[-1])
        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)

        # unpatchify
        # NOTE(review): `adaln_single` is always constructed in __init__, so this branch
        # appears unreachable here; kept to match upstream.
        if self.adaln_single is None:
            height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
        )
        # (batch_size * num_frame, C, H, W) -> (batch_size, C, num_frame, H, W)
        output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute(
            0, 2, 1, 3, 4
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/t5_film_transformer.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import math
|
| 15 |
+
from typing import Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from ..attention_processor import Attention
|
| 22 |
+
from ..embeddings import get_timestep_embedding
|
| 23 |
+
from ..modeling_utils import ModelMixin
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class T5FilmDecoder(ModelMixin, ConfigMixin):
    r"""
    T5 style decoder with FiLM conditioning.

    Args:
        input_dims (`int`, *optional*, defaults to `128`):
            The number of input dimensions.
        targets_length (`int`, *optional*, defaults to `256`):
            The length of the targets.
        max_decoder_noise_time (`float`, *optional*, defaults to `2000.0`):
            Scale applied to the continuous noise time before it is embedded as a timestep.
        d_model (`int`, *optional*, defaults to `768`):
            Size of the input hidden states.
        num_layers (`int`, *optional*, defaults to `12`):
            The number of `DecoderLayer`'s to use.
        num_heads (`int`, *optional*, defaults to `12`):
            The number of attention heads to use.
        d_kv (`int`, *optional*, defaults to `64`):
            Size of the key-value projection vectors.
        d_ff (`int`, *optional*, defaults to `2048`):
            The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
        dropout_rate (`float`, *optional*, defaults to `0.1`):
            Dropout probability.
    """

    @register_to_config
    def __init__(
        self,
        input_dims: int = 128,
        targets_length: int = 256,
        max_decoder_noise_time: float = 2000.0,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        d_kv: int = 64,
        d_ff: int = 2048,
        dropout_rate: float = 0.1,
    ):
        """Build the conditioning MLP, position table, decoder stack, and output projection."""
        super().__init__()

        # MLP that lifts the d_model timestep embedding to 4 * d_model for FiLM conditioning.
        self.conditioning_emb = nn.Sequential(
            nn.Linear(d_model, d_model * 4, bias=False),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model * 4, bias=False),
            nn.SiLU(),
        )

        # Frozen positional table: stored as an Embedding but never trained.
        self.position_encoding = nn.Embedding(targets_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_rate)

        self.decoders = nn.ModuleList()
        for lyr_num in range(num_layers):
            # FiLM conditional T5 decoder
            lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
            self.decoders.append(lyr)

        self.decoder_norm = T5LayerNorm(d_model)

        self.post_dropout = nn.Dropout(p=dropout_rate)
        self.spec_out = nn.Linear(d_model, input_dims, bias=False)

    def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
        """Combine query and key keep-masks into a (batch, 1, query_len, key_len) attention mask."""
        # Outer product of the two masks, then insert a broadcastable head axis.
        mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
        return mask.unsqueeze(-3)

    def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
        """
        Run the FiLM-conditioned decoder stack over continuous inputs.

        Args:
            encodings_and_masks: iterable of `(encoding, mask)` pairs produced by the encoder(s).
            decoder_input_tokens: continuous inputs of shape `(batch, seq_len, input_dims)`.
            decoder_noise_time: per-sample noise times in `[0, 1)`, shape `(batch,)`.

        Returns:
            `torch.Tensor`: predicted output of shape `(batch, seq_len, input_dims)`.
        """
        batch, _, _ = decoder_input_tokens.shape
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is to match upstream.
        assert decoder_noise_time.shape == (batch,)

        # decoder_noise_time is in [0, 1), so rescale to expected timing range.
        time_steps = get_timestep_embedding(
            decoder_noise_time * self.config.max_decoder_noise_time,
            embedding_dim=self.config.d_model,
            max_period=self.config.max_decoder_noise_time,
        ).to(dtype=self.dtype)

        conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)

        assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)

        seq_length = decoder_input_tokens.shape[1]

        # If we want to use relative positions for audio context, we can just offset
        # this sequence by the length of encodings_and_masks.
        decoder_positions = torch.broadcast_to(
            torch.arange(seq_length, device=decoder_input_tokens.device),
            (batch, seq_length),
        )

        position_encodings = self.position_encoding(decoder_positions)

        inputs = self.continuous_inputs_projection(decoder_input_tokens)
        inputs += position_encodings
        y = self.dropout(inputs)

        # decoder: No padding present.
        decoder_mask = torch.ones(
            decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
        )

        # Translate encoding masks to encoder-decoder masks.
        # (the comprehension's `y` is local to it and does not clobber the outer `y`)
        encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]

        # cross attend style: concat encodings
        encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
        encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)

        for lyr in self.decoders:
            # Each DecoderLayer returns a 1-tuple; unwrap the hidden states.
            y = lyr(
                y,
                conditioning_emb=conditioning_emb,
                encoder_hidden_states=encoded,
                encoder_attention_mask=encoder_decoder_mask,
            )[0]

        y = self.decoder_norm(y)
        y = self.post_dropout(y)

        spec_out = self.spec_out(y)
        return spec_out
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class DecoderLayer(nn.Module):
    r"""
    T5 decoder layer.

    Runs three residual sub-layers in order: FiLM-conditioned self-attention,
    cross-attention over the encoder outputs, and a FiLM-conditioned feed-forward.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_kv (`int`):
            Size of the key-value projection vectors.
        num_heads (`int`):
            Number of attention heads.
        d_ff (`int`):
            Size of the intermediate feed-forward layer.
        dropout_rate (`float`):
            Dropout probability.
        layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
            A small value used for numerical stability to avoid dividing by zero.
    """

    def __init__(
        self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
    ):
        super().__init__()
        # Sub-layer order matters: [0] cond self-attention, [1] cross-attention, [2] FiLM FF.
        self.layer = nn.ModuleList(
            [
                T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate),
                T5LayerCrossAttention(
                    d_model=d_model,
                    d_kv=d_kv,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    layer_norm_epsilon=layer_norm_epsilon,
                ),
                T5LayerFFCond(
                    d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon
                ),
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        conditioning_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_decoder_position_bias=None,
    ) -> Tuple[torch.Tensor]:
        """Apply the three sub-layers and return the result as a 1-tuple."""
        self_attn, cross_attn, film_ff = self.layer

        hidden_states = self_attn(
            hidden_states,
            conditioning_emb=conditioning_emb,
            attention_mask=attention_mask,
        )

        if encoder_hidden_states is not None:
            # Turn the keep/discard mask into an additive bias: 0 for kept positions,
            # a large negative value for masked ones.
            encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
                encoder_hidden_states.dtype
            )

            hidden_states = cross_attn(
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_extended_attention_mask,
            )

        # Apply Film Conditional Feed Forward layer
        hidden_states = film_ff(hidden_states, conditioning_emb)

        return (hidden_states,)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class T5LayerSelfAttentionCond(nn.Module):
|
| 229 |
+
r"""
|
| 230 |
+
T5 style self-attention layer with conditioning.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
d_model (`int`):
|
| 234 |
+
Size of the input hidden states.
|
| 235 |
+
d_kv (`int`):
|
| 236 |
+
Size of the key-value projection vectors.
|
| 237 |
+
num_heads (`int`):
|
| 238 |
+
Number of attention heads.
|
| 239 |
+
dropout_rate (`float`):
|
| 240 |
+
Dropout probability.
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
|
| 244 |
+
super().__init__()
|
| 245 |
+
self.layer_norm = T5LayerNorm(d_model)
|
| 246 |
+
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
| 247 |
+
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
| 248 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 249 |
+
|
| 250 |
+
def forward(
|
| 251 |
+
self,
|
| 252 |
+
hidden_states: torch.Tensor,
|
| 253 |
+
conditioning_emb: Optional[torch.Tensor] = None,
|
| 254 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 255 |
+
) -> torch.Tensor:
|
| 256 |
+
# pre_self_attention_layer_norm
|
| 257 |
+
normed_hidden_states = self.layer_norm(hidden_states)
|
| 258 |
+
|
| 259 |
+
if conditioning_emb is not None:
|
| 260 |
+
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
|
| 261 |
+
|
| 262 |
+
# Self-attention block
|
| 263 |
+
attention_output = self.attention(normed_hidden_states)
|
| 264 |
+
|
| 265 |
+
hidden_states = hidden_states + self.dropout(attention_output)
|
| 266 |
+
|
| 267 |
+
return hidden_states
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class T5LayerCrossAttention(nn.Module):
|
| 271 |
+
r"""
|
| 272 |
+
T5 style cross-attention layer.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
d_model (`int`):
|
| 276 |
+
Size of the input hidden states.
|
| 277 |
+
d_kv (`int`):
|
| 278 |
+
Size of the key-value projection vectors.
|
| 279 |
+
num_heads (`int`):
|
| 280 |
+
Number of attention heads.
|
| 281 |
+
dropout_rate (`float`):
|
| 282 |
+
Dropout probability.
|
| 283 |
+
layer_norm_epsilon (`float`):
|
| 284 |
+
A small value used for numerical stability to avoid dividing by zero.
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
|
| 288 |
+
super().__init__()
|
| 289 |
+
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
| 290 |
+
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
| 291 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 292 |
+
|
| 293 |
+
def forward(
|
| 294 |
+
self,
|
| 295 |
+
hidden_states: torch.Tensor,
|
| 296 |
+
key_value_states: Optional[torch.Tensor] = None,
|
| 297 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 298 |
+
) -> torch.Tensor:
|
| 299 |
+
normed_hidden_states = self.layer_norm(hidden_states)
|
| 300 |
+
attention_output = self.attention(
|
| 301 |
+
normed_hidden_states,
|
| 302 |
+
encoder_hidden_states=key_value_states,
|
| 303 |
+
attention_mask=attention_mask.squeeze(1),
|
| 304 |
+
)
|
| 305 |
+
layer_output = hidden_states + self.dropout(attention_output)
|
| 306 |
+
return layer_output
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class T5LayerFFCond(nn.Module):
|
| 310 |
+
r"""
|
| 311 |
+
T5 style feed-forward conditional layer.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
d_model (`int`):
|
| 315 |
+
Size of the input hidden states.
|
| 316 |
+
d_ff (`int`):
|
| 317 |
+
Size of the intermediate feed-forward layer.
|
| 318 |
+
dropout_rate (`float`):
|
| 319 |
+
Dropout probability.
|
| 320 |
+
layer_norm_epsilon (`float`):
|
| 321 |
+
A small value used for numerical stability to avoid dividing by zero.
|
| 322 |
+
"""
|
| 323 |
+
|
| 324 |
+
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
|
| 325 |
+
super().__init__()
|
| 326 |
+
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
|
| 327 |
+
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
| 328 |
+
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
| 329 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 330 |
+
|
| 331 |
+
def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 332 |
+
forwarded_states = self.layer_norm(hidden_states)
|
| 333 |
+
if conditioning_emb is not None:
|
| 334 |
+
forwarded_states = self.film(forwarded_states, conditioning_emb)
|
| 335 |
+
|
| 336 |
+
forwarded_states = self.DenseReluDense(forwarded_states)
|
| 337 |
+
hidden_states = hidden_states + self.dropout(forwarded_states)
|
| 338 |
+
return hidden_states
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
class T5DenseGatedActDense(nn.Module):
|
| 342 |
+
r"""
|
| 343 |
+
T5 style feed-forward layer with gated activations and dropout.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
d_model (`int`):
|
| 347 |
+
Size of the input hidden states.
|
| 348 |
+
d_ff (`int`):
|
| 349 |
+
Size of the intermediate feed-forward layer.
|
| 350 |
+
dropout_rate (`float`):
|
| 351 |
+
Dropout probability.
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
|
| 355 |
+
super().__init__()
|
| 356 |
+
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
|
| 357 |
+
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
|
| 358 |
+
self.wo = nn.Linear(d_ff, d_model, bias=False)
|
| 359 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 360 |
+
self.act = NewGELUActivation()
|
| 361 |
+
|
| 362 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 363 |
+
hidden_gelu = self.act(self.wi_0(hidden_states))
|
| 364 |
+
hidden_linear = self.wi_1(hidden_states)
|
| 365 |
+
hidden_states = hidden_gelu * hidden_linear
|
| 366 |
+
hidden_states = self.dropout(hidden_states)
|
| 367 |
+
|
| 368 |
+
hidden_states = self.wo(hidden_states)
|
| 369 |
+
return hidden_states
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class T5LayerNorm(nn.Module):
|
| 373 |
+
r"""
|
| 374 |
+
T5 style layer normalization module.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
hidden_size (`int`):
|
| 378 |
+
Size of the input hidden states.
|
| 379 |
+
eps (`float`, `optional`, defaults to `1e-6`):
|
| 380 |
+
A small value used for numerical stability to avoid dividing by zero.
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
| 384 |
+
"""
|
| 385 |
+
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
| 386 |
+
"""
|
| 387 |
+
super().__init__()
|
| 388 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 389 |
+
self.variance_epsilon = eps
|
| 390 |
+
|
| 391 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 392 |
+
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
| 393 |
+
# Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
|
| 394 |
+
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
| 395 |
+
# half-precision inputs is done in fp32
|
| 396 |
+
|
| 397 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
| 398 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 399 |
+
|
| 400 |
+
# convert into half-precision if necessary
|
| 401 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 402 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
| 403 |
+
|
| 404 |
+
return self.weight * hidden_states
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class NewGELUActivation(nn.Module):
|
| 408 |
+
"""
|
| 409 |
+
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
| 410 |
+
the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 414 |
+
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class T5FiLMLayer(nn.Module):
|
| 418 |
+
"""
|
| 419 |
+
T5 style FiLM Layer.
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
in_features (`int`):
|
| 423 |
+
Number of input features.
|
| 424 |
+
out_features (`int`):
|
| 425 |
+
Number of output features.
|
| 426 |
+
"""
|
| 427 |
+
|
| 428 |
+
def __init__(self, in_features: int, out_features: int):
|
| 429 |
+
super().__init__()
|
| 430 |
+
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
|
| 431 |
+
|
| 432 |
+
def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
|
| 433 |
+
emb = self.scale_bias(conditioning_emb)
|
| 434 |
+
scale, shift = torch.chunk(emb, 2, -1)
|
| 435 |
+
x = x * (1 + scale) + shift
|
| 436 |
+
return x
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 22 |
+
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 23 |
+
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
|
| 24 |
+
from ..cache_utils import CacheMixin
|
| 25 |
+
from ..embeddings import get_1d_rotary_pos_embed
|
| 26 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 27 |
+
from ..modeling_utils import ModelMixin
|
| 28 |
+
from ..normalization import AdaLayerNormContinuous
|
| 29 |
+
from .transformer_hunyuan_video import (
|
| 30 |
+
HunyuanVideoConditionEmbedding,
|
| 31 |
+
HunyuanVideoPatchEmbed,
|
| 32 |
+
HunyuanVideoSingleTransformerBlock,
|
| 33 |
+
HunyuanVideoTokenRefiner,
|
| 34 |
+
HunyuanVideoTransformerBlock,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
|
| 42 |
+
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
self.patch_size = patch_size
|
| 46 |
+
self.patch_size_t = patch_size_t
|
| 47 |
+
self.rope_dim = rope_dim
|
| 48 |
+
self.theta = theta
|
| 49 |
+
|
| 50 |
+
def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
|
| 51 |
+
height = height // self.patch_size
|
| 52 |
+
width = width // self.patch_size
|
| 53 |
+
grid = torch.meshgrid(
|
| 54 |
+
frame_indices.to(device=device, dtype=torch.float32),
|
| 55 |
+
torch.arange(0, height, device=device, dtype=torch.float32),
|
| 56 |
+
torch.arange(0, width, device=device, dtype=torch.float32),
|
| 57 |
+
indexing="ij",
|
| 58 |
+
) # 3 * [W, H, T]
|
| 59 |
+
grid = torch.stack(grid, dim=0) # [3, W, H, T]
|
| 60 |
+
|
| 61 |
+
freqs = []
|
| 62 |
+
for i in range(3):
|
| 63 |
+
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
|
| 64 |
+
freqs.append(freq)
|
| 65 |
+
|
| 66 |
+
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 67 |
+
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 68 |
+
|
| 69 |
+
return freqs_cos, freqs_sin
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class FramepackClipVisionProjection(nn.Module):
|
| 73 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.up = nn.Linear(in_channels, out_channels * 3)
|
| 76 |
+
self.down = nn.Linear(out_channels * 3, out_channels)
|
| 77 |
+
|
| 78 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 79 |
+
hidden_states = self.up(hidden_states)
|
| 80 |
+
hidden_states = F.silu(hidden_states)
|
| 81 |
+
hidden_states = self.down(hidden_states)
|
| 82 |
+
return hidden_states
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class HunyuanVideoHistoryPatchEmbed(nn.Module):
|
| 86 |
+
def __init__(self, in_channels: int, inner_dim: int):
|
| 87 |
+
super().__init__()
|
| 88 |
+
self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
| 89 |
+
self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
|
| 90 |
+
self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
|
| 91 |
+
|
| 92 |
+
def forward(
|
| 93 |
+
self,
|
| 94 |
+
latents_clean: Optional[torch.Tensor] = None,
|
| 95 |
+
latents_clean_2x: Optional[torch.Tensor] = None,
|
| 96 |
+
latents_clean_4x: Optional[torch.Tensor] = None,
|
| 97 |
+
):
|
| 98 |
+
if latents_clean is not None:
|
| 99 |
+
latents_clean = self.proj(latents_clean)
|
| 100 |
+
latents_clean = latents_clean.flatten(2).transpose(1, 2)
|
| 101 |
+
if latents_clean_2x is not None:
|
| 102 |
+
latents_clean_2x = _pad_for_3d_conv(latents_clean_2x, (2, 4, 4))
|
| 103 |
+
latents_clean_2x = self.proj_2x(latents_clean_2x)
|
| 104 |
+
latents_clean_2x = latents_clean_2x.flatten(2).transpose(1, 2)
|
| 105 |
+
if latents_clean_4x is not None:
|
| 106 |
+
latents_clean_4x = _pad_for_3d_conv(latents_clean_4x, (4, 8, 8))
|
| 107 |
+
latents_clean_4x = self.proj_4x(latents_clean_4x)
|
| 108 |
+
latents_clean_4x = latents_clean_4x.flatten(2).transpose(1, 2)
|
| 109 |
+
return latents_clean, latents_clean_2x, latents_clean_4x
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class HunyuanVideoFramepackTransformer3DModel(
|
| 113 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
|
| 114 |
+
):
|
| 115 |
+
_supports_gradient_checkpointing = True
|
| 116 |
+
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
| 117 |
+
_no_split_modules = [
|
| 118 |
+
"HunyuanVideoTransformerBlock",
|
| 119 |
+
"HunyuanVideoSingleTransformerBlock",
|
| 120 |
+
"HunyuanVideoHistoryPatchEmbed",
|
| 121 |
+
"HunyuanVideoTokenRefiner",
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
@register_to_config
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
in_channels: int = 16,
|
| 128 |
+
out_channels: int = 16,
|
| 129 |
+
num_attention_heads: int = 24,
|
| 130 |
+
attention_head_dim: int = 128,
|
| 131 |
+
num_layers: int = 20,
|
| 132 |
+
num_single_layers: int = 40,
|
| 133 |
+
num_refiner_layers: int = 2,
|
| 134 |
+
mlp_ratio: float = 4.0,
|
| 135 |
+
patch_size: int = 2,
|
| 136 |
+
patch_size_t: int = 1,
|
| 137 |
+
qk_norm: str = "rms_norm",
|
| 138 |
+
guidance_embeds: bool = True,
|
| 139 |
+
text_embed_dim: int = 4096,
|
| 140 |
+
pooled_projection_dim: int = 768,
|
| 141 |
+
rope_theta: float = 256.0,
|
| 142 |
+
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
| 143 |
+
image_condition_type: Optional[str] = None,
|
| 144 |
+
has_image_proj: int = False,
|
| 145 |
+
image_proj_dim: int = 1152,
|
| 146 |
+
has_clean_x_embedder: int = False,
|
| 147 |
+
) -> None:
|
| 148 |
+
super().__init__()
|
| 149 |
+
|
| 150 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 151 |
+
out_channels = out_channels or in_channels
|
| 152 |
+
|
| 153 |
+
# 1. Latent and condition embedders
|
| 154 |
+
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
|
| 155 |
+
|
| 156 |
+
# Framepack history projection embedder
|
| 157 |
+
self.clean_x_embedder = None
|
| 158 |
+
if has_clean_x_embedder:
|
| 159 |
+
self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
|
| 160 |
+
|
| 161 |
+
self.context_embedder = HunyuanVideoTokenRefiner(
|
| 162 |
+
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# Framepack image-conditioning embedder
|
| 166 |
+
self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
|
| 167 |
+
|
| 168 |
+
self.time_text_embed = HunyuanVideoConditionEmbedding(
|
| 169 |
+
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# 2. RoPE
|
| 173 |
+
self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
|
| 174 |
+
|
| 175 |
+
# 3. Dual stream transformer blocks
|
| 176 |
+
self.transformer_blocks = nn.ModuleList(
|
| 177 |
+
[
|
| 178 |
+
HunyuanVideoTransformerBlock(
|
| 179 |
+
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
| 180 |
+
)
|
| 181 |
+
for _ in range(num_layers)
|
| 182 |
+
]
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# 4. Single stream transformer blocks
|
| 186 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 187 |
+
[
|
| 188 |
+
HunyuanVideoSingleTransformerBlock(
|
| 189 |
+
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
| 190 |
+
)
|
| 191 |
+
for _ in range(num_single_layers)
|
| 192 |
+
]
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# 5. Output projection
|
| 196 |
+
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
| 197 |
+
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
|
| 198 |
+
|
| 199 |
+
self.gradient_checkpointing = False
|
| 200 |
+
|
| 201 |
+
def forward(
|
| 202 |
+
self,
|
| 203 |
+
hidden_states: torch.Tensor,
|
| 204 |
+
timestep: torch.LongTensor,
|
| 205 |
+
encoder_hidden_states: torch.Tensor,
|
| 206 |
+
encoder_attention_mask: torch.Tensor,
|
| 207 |
+
pooled_projections: torch.Tensor,
|
| 208 |
+
image_embeds: torch.Tensor,
|
| 209 |
+
indices_latents: torch.Tensor,
|
| 210 |
+
guidance: Optional[torch.Tensor] = None,
|
| 211 |
+
latents_clean: Optional[torch.Tensor] = None,
|
| 212 |
+
indices_latents_clean: Optional[torch.Tensor] = None,
|
| 213 |
+
latents_history_2x: Optional[torch.Tensor] = None,
|
| 214 |
+
indices_latents_history_2x: Optional[torch.Tensor] = None,
|
| 215 |
+
latents_history_4x: Optional[torch.Tensor] = None,
|
| 216 |
+
indices_latents_history_4x: Optional[torch.Tensor] = None,
|
| 217 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 218 |
+
return_dict: bool = True,
|
| 219 |
+
):
|
| 220 |
+
if attention_kwargs is not None:
|
| 221 |
+
attention_kwargs = attention_kwargs.copy()
|
| 222 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 223 |
+
else:
|
| 224 |
+
lora_scale = 1.0
|
| 225 |
+
|
| 226 |
+
if USE_PEFT_BACKEND:
|
| 227 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 228 |
+
scale_lora_layers(self, lora_scale)
|
| 229 |
+
else:
|
| 230 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 231 |
+
logger.warning(
|
| 232 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 236 |
+
p, p_t = self.config.patch_size, self.config.patch_size_t
|
| 237 |
+
post_patch_num_frames = num_frames // p_t
|
| 238 |
+
post_patch_height = height // p
|
| 239 |
+
post_patch_width = width // p
|
| 240 |
+
original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
|
| 241 |
+
|
| 242 |
+
if indices_latents is None:
|
| 243 |
+
indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
|
| 244 |
+
|
| 245 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 246 |
+
image_rotary_emb = self.rope(
|
| 247 |
+
frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
|
| 251 |
+
latents_clean, latents_history_2x, latents_history_4x
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
if latents_clean is not None and indices_latents_clean is not None:
|
| 255 |
+
image_rotary_emb_clean = self.rope(
|
| 256 |
+
frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
|
| 257 |
+
)
|
| 258 |
+
if latents_history_2x is not None and indices_latents_history_2x is not None:
|
| 259 |
+
image_rotary_emb_history_2x = self.rope(
|
| 260 |
+
frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
|
| 261 |
+
)
|
| 262 |
+
if latents_history_4x is not None and indices_latents_history_4x is not None:
|
| 263 |
+
image_rotary_emb_history_4x = self.rope(
|
| 264 |
+
frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
hidden_states, image_rotary_emb = self._pack_history_states(
|
| 268 |
+
hidden_states,
|
| 269 |
+
latents_clean,
|
| 270 |
+
latents_history_2x,
|
| 271 |
+
latents_history_4x,
|
| 272 |
+
image_rotary_emb,
|
| 273 |
+
image_rotary_emb_clean,
|
| 274 |
+
image_rotary_emb_history_2x,
|
| 275 |
+
image_rotary_emb_history_4x,
|
| 276 |
+
post_patch_height,
|
| 277 |
+
post_patch_width,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
|
| 281 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
| 282 |
+
|
| 283 |
+
encoder_hidden_states_image = self.image_projection(image_embeds)
|
| 284 |
+
attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1]))
|
| 285 |
+
|
| 286 |
+
# must cat before (not after) encoder_hidden_states, due to attn masking
|
| 287 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
|
| 288 |
+
encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1)
|
| 289 |
+
|
| 290 |
+
latent_sequence_length = hidden_states.shape[1]
|
| 291 |
+
condition_sequence_length = encoder_hidden_states.shape[1]
|
| 292 |
+
sequence_length = latent_sequence_length + condition_sequence_length
|
| 293 |
+
attention_mask = torch.zeros(
|
| 294 |
+
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
|
| 295 |
+
) # [B, N]
|
| 296 |
+
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
|
| 297 |
+
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
|
| 298 |
+
|
| 299 |
+
if batch_size == 1:
|
| 300 |
+
encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
|
| 301 |
+
attention_mask = None
|
| 302 |
+
else:
|
| 303 |
+
for i in range(batch_size):
|
| 304 |
+
attention_mask[i, : effective_sequence_length[i]] = True
|
| 305 |
+
# [B, 1, 1, N], for broadcasting across attention heads
|
| 306 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
|
| 307 |
+
|
| 308 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 309 |
+
for block in self.transformer_blocks:
|
| 310 |
+
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
| 311 |
+
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
for block in self.single_transformer_blocks:
|
| 315 |
+
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
| 316 |
+
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
else:
|
| 320 |
+
for block in self.transformer_blocks:
|
| 321 |
+
hidden_states, encoder_hidden_states = block(
|
| 322 |
+
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
for block in self.single_transformer_blocks:
|
| 326 |
+
hidden_states, encoder_hidden_states = block(
|
| 327 |
+
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
hidden_states = hidden_states[:, -original_context_length:]
|
| 331 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 332 |
+
hidden_states = self.proj_out(hidden_states)
|
| 333 |
+
|
| 334 |
+
hidden_states = hidden_states.reshape(
|
| 335 |
+
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
|
| 336 |
+
)
|
| 337 |
+
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
| 338 |
+
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
| 339 |
+
|
| 340 |
+
if USE_PEFT_BACKEND:
|
| 341 |
+
# remove `lora_scale` from each PEFT layer
|
| 342 |
+
unscale_lora_layers(self, lora_scale)
|
| 343 |
+
|
| 344 |
+
if not return_dict:
|
| 345 |
+
return (hidden_states,)
|
| 346 |
+
return Transformer2DModelOutput(sample=hidden_states)
|
| 347 |
+
|
| 348 |
+
def _pack_history_states(
|
| 349 |
+
self,
|
| 350 |
+
hidden_states: torch.Tensor,
|
| 351 |
+
latents_clean: Optional[torch.Tensor] = None,
|
| 352 |
+
latents_history_2x: Optional[torch.Tensor] = None,
|
| 353 |
+
latents_history_4x: Optional[torch.Tensor] = None,
|
| 354 |
+
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
|
| 355 |
+
image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 356 |
+
image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 357 |
+
image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 358 |
+
height: int = None,
|
| 359 |
+
width: int = None,
|
| 360 |
+
):
|
| 361 |
+
image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification
|
| 362 |
+
|
| 363 |
+
if latents_clean is not None and image_rotary_emb_clean is not None:
|
| 364 |
+
hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
|
| 365 |
+
image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
|
| 366 |
+
image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
|
| 367 |
+
|
| 368 |
+
if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
|
| 369 |
+
hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
|
| 370 |
+
image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
|
| 371 |
+
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
|
| 372 |
+
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
|
| 373 |
+
|
| 374 |
+
if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
|
| 375 |
+
hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
|
| 376 |
+
image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
|
| 377 |
+
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
|
| 378 |
+
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
|
| 379 |
+
|
| 380 |
+
return hidden_states, tuple(image_rotary_emb)
|
| 381 |
+
|
| 382 |
+
def _pad_rotary_emb(
|
| 383 |
+
self,
|
| 384 |
+
image_rotary_emb: Tuple[torch.Tensor],
|
| 385 |
+
height: int,
|
| 386 |
+
width: int,
|
| 387 |
+
kernel_size: Tuple[int, int, int],
|
| 388 |
+
):
|
| 389 |
+
# freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
|
| 390 |
+
freqs_cos, freqs_sin = image_rotary_emb
|
| 391 |
+
freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
|
| 392 |
+
freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
|
| 393 |
+
freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
|
| 394 |
+
freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
|
| 395 |
+
freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
|
| 396 |
+
freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size)
|
| 397 |
+
freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0)
|
| 398 |
+
freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
|
| 399 |
+
return freqs_cos, freqs_sin
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def _pad_for_3d_conv(x, kernel_size):
|
| 403 |
+
if isinstance(x, (tuple, list)):
|
| 404 |
+
return tuple(_pad_for_3d_conv(i, kernel_size) for i in x)
|
| 405 |
+
b, c, t, h, w = x.shape
|
| 406 |
+
pt, ph, pw = kernel_size
|
| 407 |
+
pad_t = (pt - (t % pt)) % pt
|
| 408 |
+
pad_h = (ph - (h % ph)) % ph
|
| 409 |
+
pad_w = (pw - (w % pw)) % pw
|
| 410 |
+
return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def _center_down_sample_3d(x, kernel_size):
|
| 414 |
+
if isinstance(x, (tuple, list)):
|
| 415 |
+
return tuple(_center_down_sample_3d(i, kernel_size) for i in x)
|
| 416 |
+
return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_ltx.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 25 |
+
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
| 26 |
+
from ...utils.torch_utils import maybe_allow_in_graph
|
| 27 |
+
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
| 28 |
+
from ..attention_dispatch import dispatch_attention_fn
|
| 29 |
+
from ..cache_utils import CacheMixin
|
| 30 |
+
from ..embeddings import PixArtAlphaTextProjection
|
| 31 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 32 |
+
from ..modeling_utils import ModelMixin
|
| 33 |
+
from ..normalization import AdaLayerNormSingle, RMSNorm
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class LTXVideoAttentionProcessor2_0:
    r"""
    Deprecated alias kept for backward compatibility; use `LTXVideoAttnProcessor` instead.

    Instantiating this class emits a deprecation warning and returns an
    `LTXVideoAttnProcessor` instance. Because `__new__` returns an instance of a
    different class, this class's own `__init__` is never invoked.
    """

    def __new__(cls, *args, **kwargs):
        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)

        # Forward all constructor arguments to the replacement processor.
        return LTXVideoAttnProcessor(*args, **kwargs)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LTXVideoAttnProcessor:
    r"""
    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
    model. It applies a normalization layer and rotary embedding on the query and key vector.
    """

    # Optional override for the attention backend passed to `dispatch_attention_fn`;
    # `None` lets the dispatcher pick its default.
    _attention_backend = None

    def __init__(self):
        if is_torch_version("<", "2.0"):
            raise ValueError(
                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
            )

    def __call__(
        self,
        attn: "LTXAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Mask length follows the key/value source: the encoder states for
        # cross-attention, otherwise the query states (self-attention).
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Expand to [B, heads, 1 (or Q), KV] for the attention kernel.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        # Self-attention when no encoder states are provided.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # QK RMS-norm is applied on the full feature dim, i.e. before the head split.
        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # Rotary embeddings are likewise applied on the pre-head-split Q/K.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # [B, S, H * D] -> [B, S, H, D]
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )
        # [B, S, H, D] -> [B, S, H * D], cast back to the query dtype.
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # Output projection (to_out[0]) followed by dropout (to_out[1]).
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class LTXAttention(torch.nn.Module, AttentionModuleMixin):
    """Multi-head attention module used by LTX transformer blocks.

    Supports self-attention (``cross_attention_dim=None``) and cross-attention.
    Query/key RMS normalization is applied across all heads; only the
    ``"rms_norm_across_heads"`` mode is supported.
    """

    _default_processor_cls = LTXVideoAttnProcessor
    _available_processors = [LTXVideoAttnProcessor]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        kv_heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = True,
        cross_attention_dim: Optional[int] = None,
        out_bias: bool = True,
        qk_norm: str = "rms_norm_across_heads",
        processor=None,
    ):
        super().__init__()
        if qk_norm != "rms_norm_across_heads":
            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")

        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        # Key/value projections may use a different head count than the query.
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        # Falls back to self-attention dims when no cross-attention dim is given.
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = query_dim
        self.heads = heads

        norm_eps = 1e-5
        norm_elementwise_affine = True
        # Norms span the full (heads * dim_head) feature dim — "across heads".
        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        # to_out[0] = output projection, to_out[1] = dropout.
        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(torch.nn.Dropout(dropout))

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Filter out kwargs the active processor does not accept (with a warning)
        # so swapping processors does not hard-fail on extra arguments.
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class LTXVideoRotaryPosEmbed(nn.Module):
    """3D rotary position embedding over (frame, height, width) for LTX video tokens.

    `forward` returns a `(cos, sin)` pair with one row per flattened video token;
    the pair is consumed by `apply_rotary_emb` inside the attention processor.
    """

    def __init__(
        self,
        dim: int,
        base_num_frames: int = 20,
        base_height: int = 2048,
        base_width: int = 2048,
        patch_size: int = 1,
        patch_size_t: int = 1,
        theta: float = 10000.0,
    ) -> None:
        super().__init__()

        self.dim = dim
        # Base extents used to normalize coordinates to a resolution-independent range.
        self.base_num_frames = base_num_frames
        self.base_height = base_height
        self.base_width = base_width
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.theta = theta

    def _prepare_video_coords(
        self,
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Tuple[torch.Tensor, float, float],
        device: torch.device,
    ) -> torch.Tensor:
        """Build normalized (frame, height, width) coordinates of shape [B, F*H*W, 3]."""
        # Always compute rope in fp32
        grid_h = torch.arange(height, dtype=torch.float32, device=device)
        grid_w = torch.arange(width, dtype=torch.float32, device=device)
        grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
        grid = torch.stack(grid, dim=0)  # [3, F, H, W]
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)  # [B, 3, F, H, W]

        if rope_interpolation_scale is not None:
            # Scale each axis by its interpolation factor and normalize by the
            # base extent (in patches) so positions are resolution-independent.
            grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width

        # [B, 3, F, H, W] -> [B, F*H*W, 3]
        grid = grid.flatten(2, 4).transpose(1, 2)

        return grid

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
        video_coords: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = hidden_states.size(0)

        if video_coords is None:
            grid = self._prepare_video_coords(
                batch_size,
                num_frames,
                height,
                width,
                rope_interpolation_scale=rope_interpolation_scale,
                device=hidden_states.device,
            )
        else:
            # Caller-supplied per-token coordinates; normalize by the base extents.
            # Assumes video_coords is [B, 3, S] with axes (frame, height, width)
            # — matches how callers index [:, 0], [:, 1], [:, 2].
            grid = torch.stack(
                [
                    video_coords[:, 0] / self.base_num_frames,
                    video_coords[:, 1] / self.base_height,
                    video_coords[:, 2] / self.base_width,
                ],
                dim=-1,
            )

        # Log-spaced frequencies in [1, theta]; dim // 6 frequencies per axis
        # (3 axes, each later duplicated into cos/sin pairs -> dim channels).
        start = 1.0
        end = self.theta
        freqs = self.theta ** torch.linspace(
            math.log(start, self.theta),
            math.log(end, self.theta),
            self.dim // 6,
            device=hidden_states.device,
            dtype=torch.float32,
        )
        freqs = freqs * math.pi / 2.0
        # Map coordinates from [0, 1] to [-1, 1] before scaling by the frequencies.
        freqs = freqs * (grid.unsqueeze(-1) * 2 - 1)
        freqs = freqs.transpose(-1, -2).flatten(2)

        # Duplicate each frequency so adjacent channels form (cos, cos)/(sin, sin)
        # pairs, matching the pairwise rotation in `apply_rotary_emb`.
        cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
        sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)

        if self.dim % 6 != 0:
            # Pad the leftover channels with the identity rotation (cos=1, sin=0).
            cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6])
            sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6])
            cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
            sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)

        return cos_freqs, sin_freqs
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@maybe_allow_in_graph
class LTXVideoTransformerBlock(nn.Module):
    r"""
    Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        cross_attention_dim (`int`):
            The number of channels in the encoder hidden states for cross-attention.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            The normalization layer to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: int,
        qk_norm: str = "rms_norm_across_heads",
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = True,
        attention_out_bias: bool = True,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
    ):
        super().__init__()

        # Self-attention over the video tokens (with rotary embeddings).
        self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn1 = LTXAttention(
            query_dim=dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            cross_attention_dim=None,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
        )

        # Cross-attention over the encoder (caption) hidden states.
        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn2 = LTXAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
        )

        self.ff = FeedForward(dim, activation_fn=activation_fn)

        # Learned base values for the 6 AdaLN parameters
        # (shift/scale/gate for attention and for the MLP).
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.size(0)
        norm_hidden_states = self.norm1(hidden_states)

        # AdaLN-single: learned table + timestep conditioning produce per-token
        # shift/scale/gate values for both the attention and MLP branches.
        num_ada_params = self.scale_shift_table.shape[0]
        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        # Gated residual self-attention with rotary embeddings.
        attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_hidden_states * gate_msa

        # Ungated residual cross-attention over the caption states.
        # NOTE: attn2 receives the un-normalized hidden states (no norm2 here);
        # norm2 is applied afterwards for the feed-forward branch.
        attn_hidden_states = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=None,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = hidden_states + attn_hidden_states
        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp

        # Gated residual feed-forward.
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + ff_output * gate_mlp

        return hidden_states
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
@maybe_allow_in_graph
class LTXVideoTransformer3DModel(
    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
):
    r"""
    A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

    Args:
        in_channels (`int`, defaults to `128`):
            The number of channels in the input.
        out_channels (`int`, defaults to `128`):
            The number of channels in the output.
        patch_size (`int`, defaults to `1`):
            The size of the spatial patches to use in the patch embedding layer.
        patch_size_t (`int`, defaults to `1`):
            The size of the temporal patches to use in the patch embedding layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        cross_attention_dim (`int`, defaults to `2048 `):
            The number of channels for cross attention heads.
        num_layers (`int`, defaults to `28`):
            The number of layers of Transformer blocks to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            The normalization layer to use.
    """

    _supports_gradient_checkpointing = True
    # Norm layers are kept in full precision under layerwise casting.
    _skip_layerwise_casting_patterns = ["norm"]
    _repeated_blocks = ["LTXVideoTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        patch_size_t: int = 1,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        cross_attention_dim: int = 2048,
        num_layers: int = 28,
        activation_fn: str = "gelu-approximate",
        qk_norm: str = "rms_norm_across_heads",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        caption_channels: int = 4096,
        attention_bias: bool = True,
        attention_out_bias: bool = True,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        self.proj_in = nn.Linear(in_channels, inner_dim)

        # Learned base (shift, scale) for the final output modulation.
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)

        # Projects caption features (e.g. T5 embeddings) into the model dim.
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.rope = LTXVideoRotaryPosEmbed(
            dim=inner_dim,
            base_num_frames=20,
            base_height=2048,
            base_width=2048,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            theta=10000.0,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                LTXVideoTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=qk_norm,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    attention_out_bias=attention_out_bias,
                    eps=norm_eps,
                    elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
        video_coords: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
        # `scale` inside attention_kwargs controls the LoRA strength (PEFT only).
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        batch_size = hidden_states.size(0)
        hidden_states = self.proj_in(hidden_states)

        temb, embedded_timestep = self.time_embed(
            timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )

        # Reshape flattened timestep embeddings back to per-batch sequences.
        temb = temb.view(batch_size, -1, temb.size(-1))
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    encoder_attention_mask,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    encoder_attention_mask=encoder_attention_mask,
                )

        # Final AdaLN-style modulation from the embedded timestep before projection.
        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states * (1 + scale) + shift
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def apply_rotary_emb(x, freqs):
    """Rotate `x` by interleaved rotary position embeddings.

    Args:
        x: Tensor of shape [B, S, C] with an even channel count; consecutive
            channel pairs are treated as (real, imaginary) components.
        freqs: Tuple `(cos, sin)` of tensors broadcastable against `x`.

    Returns:
        Tensor with the same shape and dtype as `x`.
    """
    cos_vals, sin_vals = freqs
    # Split channels into (C // 2, 2) pairs: index 0 holds the "real" part,
    # index 1 the "imaginary" part.
    pairs = x.unflatten(2, (-1, 2))
    real_part = pairs[..., 0]
    imag_part = pairs[..., 1]
    # 90-degree rotation of every pair: (r, i) -> (-i, r), re-interleaved.
    rotated = torch.stack((-imag_part, real_part), dim=-1).flatten(2)
    # Accumulate in fp32, then cast back to the input dtype.
    result = x.float() * cos_vals + rotated.float() * sin_vals
    return result.to(x.dtype)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_lumina2.py
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from ...configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from ...loaders import PeftAdapterMixin
|
| 24 |
+
from ...loaders.single_file_model import FromOriginalModelMixin
|
| 25 |
+
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 26 |
+
from ..attention import LuminaFeedForward
|
| 27 |
+
from ..attention_processor import Attention
|
| 28 |
+
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
|
| 29 |
+
from ..modeling_outputs import Transformer2DModelOutput
|
| 30 |
+
from ..modeling_utils import ModelMixin
|
| 31 |
+
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    """Joint embedding of the diffusion timestep and the caption features.

    The scalar timestep is expanded into sinusoidal features and projected by an
    MLP; the caption features are RMS-normalized and linearly projected to the
    transformer width. Both conditioning tensors are returned together.
    """

    def __init__(
        self,
        hidden_size: int = 4096,
        cap_feat_dim: int = 2048,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
    ) -> None:
        super().__init__()

        # Sinusoidal features for the scalar timestep.
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
        )

        # The time-conditioning width is capped at 1024 even for wider models.
        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        # Normalize caption features, then project them to the model width.
        self.caption_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
        )

    def forward(
        self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(time_embed, caption_embed)``; `hidden_states` only supplies the target dtype."""
        sinusoid = self.time_proj(timestep).type_as(hidden_states)
        return self.timestep_embedder(sinusoid), self.caption_embedder(encoder_hidden_states)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Lumina2AttnProcessor2_0:
|
| 69 |
+
r"""
|
| 70 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
| 71 |
+
used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(self):
|
| 75 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 76 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 77 |
+
|
| 78 |
+
def __call__(
|
| 79 |
+
self,
|
| 80 |
+
attn: Attention,
|
| 81 |
+
hidden_states: torch.Tensor,
|
| 82 |
+
encoder_hidden_states: torch.Tensor,
|
| 83 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 84 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 85 |
+
base_sequence_length: Optional[int] = None,
|
| 86 |
+
) -> torch.Tensor:
|
| 87 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 88 |
+
|
| 89 |
+
# Get Query-Key-Value Pair
|
| 90 |
+
query = attn.to_q(hidden_states)
|
| 91 |
+
key = attn.to_k(encoder_hidden_states)
|
| 92 |
+
value = attn.to_v(encoder_hidden_states)
|
| 93 |
+
|
| 94 |
+
query_dim = query.shape[-1]
|
| 95 |
+
inner_dim = key.shape[-1]
|
| 96 |
+
head_dim = query_dim // attn.heads
|
| 97 |
+
dtype = query.dtype
|
| 98 |
+
|
| 99 |
+
# Get key-value heads
|
| 100 |
+
kv_heads = inner_dim // head_dim
|
| 101 |
+
|
| 102 |
+
query = query.view(batch_size, -1, attn.heads, head_dim)
|
| 103 |
+
key = key.view(batch_size, -1, kv_heads, head_dim)
|
| 104 |
+
value = value.view(batch_size, -1, kv_heads, head_dim)
|
| 105 |
+
|
| 106 |
+
# Apply Query-Key Norm if needed
|
| 107 |
+
if attn.norm_q is not None:
|
| 108 |
+
query = attn.norm_q(query)
|
| 109 |
+
if attn.norm_k is not None:
|
| 110 |
+
key = attn.norm_k(key)
|
| 111 |
+
|
| 112 |
+
# Apply RoPE if needed
|
| 113 |
+
if image_rotary_emb is not None:
|
| 114 |
+
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
| 115 |
+
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
| 116 |
+
|
| 117 |
+
query, key = query.to(dtype), key.to(dtype)
|
| 118 |
+
|
| 119 |
+
# Apply proportional attention if true
|
| 120 |
+
if base_sequence_length is not None:
|
| 121 |
+
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
| 122 |
+
else:
|
| 123 |
+
softmax_scale = attn.scale
|
| 124 |
+
|
| 125 |
+
# perform Grouped-qurey Attention (GQA)
|
| 126 |
+
n_rep = attn.heads // kv_heads
|
| 127 |
+
if n_rep >= 1:
|
| 128 |
+
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
| 129 |
+
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
| 130 |
+
|
| 131 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 132 |
+
# (batch, heads, source_length, target_length)
|
| 133 |
+
if attention_mask is not None:
|
| 134 |
+
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
|
| 135 |
+
|
| 136 |
+
query = query.transpose(1, 2)
|
| 137 |
+
key = key.transpose(1, 2)
|
| 138 |
+
value = value.transpose(1, 2)
|
| 139 |
+
|
| 140 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 141 |
+
query, key, value, attn_mask=attention_mask, scale=softmax_scale
|
| 142 |
+
)
|
| 143 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 144 |
+
hidden_states = hidden_states.type_as(query)
|
| 145 |
+
|
| 146 |
+
# linear proj
|
| 147 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 148 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 149 |
+
return hidden_states
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Lumina2TransformerBlock(nn.Module):
    """One Lumina2 transformer block: self-attention then feed-forward, each
    wrapped in sandwich RMS norms.

    With ``modulation=True`` the pre-attention norm additionally produces
    gates and a scale from the timestep embedding `temb`, which modulate the
    residual branches; with ``modulation=False`` the block is unconditioned.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        # Self-attention with RMS query/key norm and grouped KV heads.
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=Lumina2AttnProcessor2_0(),
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        # Pre-attention norm: timestep-modulated or plain, depending on `modulation`.
        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim,
                norm_eps=norm_eps,
                norm_elementwise_affine=True,
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        # Post-branch ("sandwich") norms applied to each residual contribution.
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the block; `temb` is only consumed when ``modulation=True``."""
        if not self.modulation:
            normed = self.norm1(hidden_states)
            attn_out = self.attn(
                hidden_states=normed,
                encoder_hidden_states=normed,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            hidden_states = hidden_states + self.norm2(attn_out)
            ffn_out = self.feed_forward(self.ffn_norm1(hidden_states))
            return hidden_states + self.ffn_norm2(ffn_out)

        normed, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
        attn_out = self.attn(
            hidden_states=normed,
            encoder_hidden_states=normed,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        # Gates pass through tanh so each residual contribution starts near zero.
        hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_out)
        ffn_out = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
        return hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(ffn_out)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class Lumina2RotaryPosEmbed(nn.Module):
    """Rotary positional embedding for Lumina2's joint caption+image sequence.

    Precomputes one 1D rotary frequency table per axis (axis 0: caption token
    position, axes 1-2: image patch row/column). `forward` patchifies the
    latent image and gathers the rotary embeddings for the packed
    caption+image sequence of every batch element.
    """

    def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

        self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)

    def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
        """Build one rotary table per axis, indexed by position along that axis."""
        # float64 is unsupported on MPS, so fall back to float32 there.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        freqs_cis = []
        # Note: earlier revisions shadowed `theta` with `self.theta` here (same
        # value, since this is only called from __init__); use the parameter.
        for axis_dim, axis_len in zip(axes_dim, axes_lens):
            freqs_cis.append(get_1d_rotary_pos_embed(axis_dim, axis_len, theta=theta, freqs_dtype=freqs_dtype))
        return freqs_cis

    def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
        """Gather rotary values for integer position ids of shape (batch, seq, num_axes)."""
        device = ids.device
        if ids.device.type == "mps":
            # NOTE(review): presumably a workaround for limited MPS op coverage —
            # the gather runs on CPU and the result is moved back to `device` below.
            ids = ids.to("cpu")

        result = []
        for i in range(len(self.axes_dim)):
            freqs = self.freqs_cis[i].to(ids.device)
            # Broadcast each per-axis position id across that axis's frequency width.
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        # Concatenate the per-axis embeddings along the last dim.
        return torch.cat(result, dim=-1).to(device)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        """Patchify `hidden_states` and build rotary embeddings.

        Args:
            hidden_states: Latent image, (batch, channels, height, width).
            attention_mask: Caption mask, (batch, encoder_seq_len); its per-row
                sum is the effective caption length.

        Returns:
            Tuple of (patchified hidden states, caption rotary embeddings,
            image rotary embeddings, joint rotary embeddings, effective caption
            lengths, joint sequence lengths).
        """
        batch_size, channels, height, width = hidden_states.shape
        p = self.patch_size
        post_patch_height, post_patch_width = height // p, width // p
        image_seq_len = post_patch_height * post_patch_width
        device = hidden_states.device

        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()
        # Each joint sequence is [caption tokens | image tokens].
        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
        max_seq_len = max(seq_lengths)

        # Create position IDs
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # add caption position ids: axis 0 counts caption tokens; all image
            # tokens share the caption-length position on that axis.
            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
            position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len

            # add image position ids: axes 1/2 are the patch row/column indices.
            row_ids = (
                torch.arange(post_patch_height, dtype=torch.int32, device=device)
                .view(-1, 1)
                .repeat(1, post_patch_width)
                .flatten()
            )
            col_ids = (
                torch.arange(post_patch_width, dtype=torch.int32, device=device)
                .view(1, -1)
                .repeat(post_patch_height, 1)
                .flatten()
            )
            position_ids[i, cap_seq_len:seq_len, 1] = row_ids
            position_ids[i, cap_seq_len:seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(position_ids)

        # create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]

        # image patch embeddings: (B, C, H, W) -> (B, num_patches, p*p*C)
        hidden_states = (
            hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
            .permute(0, 2, 4, 3, 5, 1)
            .flatten(3)
            .flatten(1, 2)
        )

        return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    Lumina2NextDiT: Diffusion model with a Transformer backbone.

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
        in_channels (`int`, *optional*, defaults to 16):
            The number of input channels for the model. Typically, this matches the number of channels in the input
            images.
        hidden_size (`int`, *optional*, defaults to 2304):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        num_layers (`int`, *optional*, default to 26):
            The number of layers in the model. This defines the depth of the neural network.
        num_attention_heads (`int`, *optional*, defaults to 24):
            The number of attention heads in each attention layer. This parameter specifies how many separate attention
            mechanisms are used.
        num_kv_heads (`int`, *optional*, defaults to 8):
            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
            If None, it defaults to num_attention_heads.
        multiple_of (`int`, *optional*, defaults to 256):
            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
            configurations.
        ffn_dim_multiplier (`float`, *optional*):
            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
            the model configuration.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the denominator for numerical stability in normalization layers.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
            overall scale of the model's operations.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["Lumina2TransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        scaling_factor: float = 1.0,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        cap_feat_dim: int = 1024,
    ) -> None:
        super().__init__()
        # Default to the input channel count when no explicit output count is given.
        self.out_channels = out_channels or in_channels

        # 1. Positional, patch & conditional embeddings
        self.rope_embedder = Lumina2RotaryPosEmbed(
            theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
        )

        # Linear patch embedding: each p*p*C patch vector -> hidden_size.
        self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
        )

        # 2. Noise and context refinement blocks
        # Noise refiners are timestep-modulated; context refiners are not.
        self.noise_refiner = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        self.context_refiner = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels,
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """Denoise `hidden_states` (latent image, (B, C, H, W)) conditioned on
        `timestep` and caption features `encoder_hidden_states`, whose valid
        tokens are flagged by `encoder_attention_mask`.

        Returns the predicted latents, as a tuple when `return_dict=False`, else
        wrapped in a `Transformer2DModelOutput`.
        """
        # Pop the LoRA scale out of attention_kwargs (copied so the caller's
        # dict is not mutated).
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # 1. Condition, positional & patch embedding
        batch_size, _, height, width = hidden_states.shape

        temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)

        # rope_embedder also patchifies hidden_states to (B, num_patches, p*p*C).
        (
            hidden_states,
            context_rotary_emb,
            noise_rotary_emb,
            rotary_emb,
            encoder_seq_lengths,
            seq_lengths,
        ) = self.rope_embedder(hidden_states, encoder_attention_mask)

        hidden_states = self.x_embedder(hidden_states)

        # 2. Context & noise refinement
        for layer in self.context_refiner:
            encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)

        # 3. Joint Transformer blocks
        max_seq_len = max(seq_lengths)
        # If every row packs to the same length there is no padding, so the
        # attention mask can be skipped entirely.
        use_mask = len(set(seq_lengths)) > 1

        # Pack each row as [valid caption tokens | image tokens], zero-padded
        # to max_seq_len; the mask marks the valid prefix.
        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            attention_mask[i, :seq_len] = True
            joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
            joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]

        hidden_states = joint_hidden_states

        for layer in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
                )
            else:
                hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)

        # 4. Output norm & projection
        hidden_states = self.norm_out(hidden_states, temb)

        # 5. Unpatchify
        # Per row, drop the caption tokens and fold the image tokens back into
        # a (out_channels, H, W) latent.
        p = self.config.patch_size
        output = []
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            output.append(
                hidden_states[i][encoder_seq_len:seq_len]
                .view(height // p, width // p, p, p, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )
        output = torch.stack(output, dim=0)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ...utils import is_flax_available, is_torch_available
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
if is_torch_available():
|
| 5 |
+
from .unet_1d import UNet1DModel
|
| 6 |
+
from .unet_2d import UNet2DModel
|
| 7 |
+
from .unet_2d_condition import UNet2DConditionModel
|
| 8 |
+
from .unet_3d_condition import UNet3DConditionModel
|
| 9 |
+
from .unet_i2vgen_xl import I2VGenXLUNet
|
| 10 |
+
from .unet_kandinsky3 import Kandinsky3UNet
|
| 11 |
+
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
| 12 |
+
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
| 13 |
+
from .unet_stable_cascade import StableCascadeUNet
|
| 14 |
+
from .uvit_2d import UVit2DModel
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if is_flax_available():
|
| 18 |
+
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (940 Bytes). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_1d.cpython-310.pyc
ADDED
|
Binary file (7.94 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_1d_blocks.cpython-310.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d.cpython-310.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_blocks.cpython-310.pyc
ADDED
|
Binary file (60.9 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_blocks_flax.cpython-310.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition.cpython-310.pyc
ADDED
|
Binary file (40.6 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition_flax.cpython-310.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_3d_blocks.cpython-310.pyc
ADDED
|
Binary file (26.4 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_3d_condition.cpython-310.pyc
ADDED
|
Binary file (24 kB). View file
|
|
|