Commit · 6f25f9f
Parent(s): 3daab90
new

Files changed:
- .gitignore +0 -1
- videox_fun/models/__init__.py +110 -0
- videox_fun/models/attention_utils.py +211 -0
- videox_fun/models/cache_utils.py +76 -0
- videox_fun/models/cogvideox_transformer3d.py +840 -0
- videox_fun/models/cogvideox_vae.py +1675 -0
- videox_fun/models/flux_transformer2d.py +940 -0
- videox_fun/models/qwenimage_transformer2d.py +893 -0
- videox_fun/models/qwenimage_vae.py +1087 -0
- videox_fun/models/wan_camera_adapter.py +64 -0
- videox_fun/models/wan_image_encoder.py +553 -0
- videox_fun/models/wan_text_encoder.py +395 -0
- videox_fun/models/wan_transformer3d.py +1399 -0
- videox_fun/models/wan_transformer3d_s2v.py +887 -0
- videox_fun/models/wan_transformer3d_vace.py +392 -0
- videox_fun/models/wan_vae.py +706 -0
- videox_fun/models/wan_vae3_8.py +1080 -0
- videox_fun/models/wan_xlm_roberta.py +170 -0
.gitignore
CHANGED
@@ -1,4 +1,3 @@
 samples/
-models/
 __pycache__/
 *.pyc
videox_fun/models/__init__.py
ADDED
@@ -0,0 +1,110 @@
import importlib.util

from diffusers import AutoencoderKL
from transformers import (AutoTokenizer, CLIPImageProcessor, CLIPTextModel,
                          CLIPTokenizer, CLIPVisionModelWithProjection,
                          T5EncoderModel, T5Tokenizer, T5TokenizerFast)

try:
    from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
except ImportError:
    Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
    print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade your transformers package to the latest version.")

from .cogvideox_transformer3d import CogVideoXTransformer3DModel
from .cogvideox_vae import AutoencoderKLCogVideoX
from .flux_transformer2d import FluxTransformer2DModel
from .qwenimage_transformer2d import QwenImageTransformer2DModel
from .qwenimage_vae import AutoencoderKLQwenImage
# from .wan_audio_encoder import WanAudioEncoder
from .wan_image_encoder import CLIPModel
from .wan_text_encoder import WanT5EncoderModel
from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
                                WanSelfAttention, WanTransformer3DModel)
# from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
from .wan_transformer3d_vace import VaceWanTransformer3DModel
from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8

# pai_fuser is an internally developed acceleration package that can be used on PAI.
if importlib.util.find_spec("paifuser") is not None:
    # --------------------------------------------------------------- #
    #  simple_wrapper works around conflicts between
    #  cython and torch.compile
    # --------------------------------------------------------------- #
    def simple_wrapper(func):
        def inner(*args, **kwargs):
            return func(*args, **kwargs)
        return inner

    # --------------------------------------------------------------- #
    #  VAE Parallel Kernel
    # --------------------------------------------------------------- #
    from ..dist import parallel_magvit_vae
    AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
    AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))

    # --------------------------------------------------------------- #
    #  Sparse Attention
    # --------------------------------------------------------------- #
    import torch
    from paifuser.ops import wan_sparse_attention_wrapper

    WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
    print("Import Sparse Attention")

    WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)

    # --------------------------------------------------------------- #
    #  CFG Skip Turbo
    # --------------------------------------------------------------- #
    import os

    if importlib.util.find_spec("paifuser.accelerator") is not None:
        from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
                                          enable_cfg_skip, share_cfg_skip)
    else:
        from paifuser import (cfg_skip_turbo, disable_cfg_skip,
                              enable_cfg_skip, share_cfg_skip)

    WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
    WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
    WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)
    print("Import CFG Skip Turbo")

    # --------------------------------------------------------------- #
    #  RMS Norm Kernel
    # --------------------------------------------------------------- #
    from paifuser.ops import rms_norm_forward
    WanRMSNorm.forward = rms_norm_forward
    print("Import PAI RMS Fuse")

    # --------------------------------------------------------------- #
    #  Fast Rope Kernel
    # --------------------------------------------------------------- #
    import types

    import torch
    from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
                              rope_apply_real_qk)

    from . import wan_transformer3d

    def deepcopy_function(f):
        return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)

    # Keep a copy of the original rope function so training can still use it.
    local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)

    if ENABLE_KERNEL:
        def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
            if torch.is_grad_enabled():
                return local_rope_apply_qk(q, k, grid_sizes, freqs)
            else:
                return fast_rope_apply_qk(q, k, grid_sizes, freqs)
    else:
        def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
            return rope_apply_real_qk(q, k, grid_sizes, freqs)

    wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
    rope_apply_qk = adaptive_fast_rope_apply_qk
    print("Import PAI Fast rope")

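Taken together, `__init__.py` does two things: it re-exports the model classes under `videox_fun.models`, and, when the internal `paifuser` package is installed, it monkey-patches the Wan classes in place at import time, so callers pick up the accelerated kernels without changing their imports. A minimal consumption sketch follows; the checkpoint path is a placeholder, and the `from_pretrained` signature is assumed to mirror the `CogVideoXTransformer3DModel.from_pretrained` defined later in this commit.

# Hypothetical usage sketch. The path below is a placeholder, and the
# from_pretrained signature is an assumption based on the CogVideoX variant
# in this commit; the classes behave the same whether or not paifuser
# rewired their methods at import time.
from videox_fun.models import AutoencoderKLWan, WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained(
    "path/to/checkpoint", subfolder="transformer"
)
vae = AutoencoderKLWan.from_pretrained("path/to/checkpoint", subfolder="vae")
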
videox_fun/models/attention_utils.py
ADDED
@@ -0,0 +1,211 @@
import os
import warnings

import torch

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

# Default to unavailable so the capability probe below can fall through
# (e.g. on CPU-only machines or unlisted compute capabilities) without
# leaving these names undefined.
sageattn = None
SAGE_ATTENTION_AVAILABLE = False
try:
    major, minor = torch.cuda.get_device_capability(0)
    if f"{major}.{minor}" == "8.0":
        from sageattention_sm80 import sageattn
        SAGE_ATTENTION_AVAILABLE = True
    elif f"{major}.{minor}" == "8.6":
        from sageattention_sm86 import sageattn
        SAGE_ATTENTION_AVAILABLE = True
    elif f"{major}.{minor}" == "8.9":
        from sageattention_sm89 import sageattn
        SAGE_ATTENTION_AVAILABLE = True
    elif f"{major}.{minor}" == "9.0":
        from sageattention_sm90 import sageattn
        SAGE_ATTENTION_AVAILABLE = True
    elif major > 9:
        from sageattention_sm120 import sageattn
        SAGE_ATTENTION_AVAILABLE = True
except Exception:
    try:
        from sageattention import sageattn
        SAGE_ATTENTION_AVAILABLE = True
    except Exception:
        sageattn = None
        SAGE_ATTENTION_AVAILABLE = False

def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B].
    k_lens: [B].
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply a causal attention mask.
    window_size: (left, right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Applied when the dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic).unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # output
    return x.type(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
    attention_type=None,
    attn_mask=None,
):
    attention_type = os.environ.get("VIDEOX_ATTENTION_TYPE", "FLASH_ATTENTION") if attention_type is None else attention_type
    # SageAttention is inference-only; fall back to flash attention when gradients are enabled.
    if torch.is_grad_enabled() and attention_type == "SAGE_ATTENTION":
        attention_type = "FLASH_ATTENTION"

    if attention_type == "SAGE_ATTENTION" and SAGE_ATTENTION_AVAILABLE:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                'Padding mask is disabled when using SageAttention. It can have a significant impact on performance.'
            )

        out = sageattn(
            q, k, v, attn_mask=attn_mask, tensor_layout="NHD", is_causal=causal, dropout_p=dropout_p)

    elif attention_type == "FLASH_ATTENTION" and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
            )
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous()
    return out

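`attention` is the single entry point the transformers call: it reads `VIDEOX_ATTENTION_TYPE` from the environment, silently downgrades SageAttention to flash attention while gradients are enabled, and falls back to `torch.nn.functional.scaled_dot_product_attention` when neither optional backend is installed. A call sketch with illustrative shapes:

# Illustrative call sketch; the shapes are made up. q/k/v follow the documented
# [B, L, N, C] layout, and a CUDA device is required for the flash path.
import os

import torch

from videox_fun.models.attention_utils import attention

os.environ["VIDEOX_ATTENTION_TYPE"] = "FLASH_ATTENTION"  # or "SAGE_ATTENTION"

q = torch.randn(2, 1024, 12, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 12, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 12, 64, device="cuda", dtype=torch.bfloat16)

out = attention(q, k, v)  # falls back to SDPA if no optional backend is present
print(out.shape)  # torch.Size([2, 1024, 12, 64])
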
videox_fun/models/cache_utils.py
ADDED
@@ -0,0 +1,76 @@
import numpy as np
import torch

def get_teacache_coefficients(model_name):
    if "wan2.1-t2v-1.3b" in model_name.lower() or "wan2.1-fun-1.3b" in model_name.lower() \
        or "wan2.1-fun-v1.1-1.3b" in model_name.lower() or "wan2.1-vace-1.3b" in model_name.lower():
        return [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
    elif "wan2.1-t2v-14b" in model_name.lower():
        return [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
    elif "wan2.1-i2v-14b-480p" in model_name.lower():
        return [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
    elif "wan2.1-i2v-14b-720p" in model_name.lower() or "wan2.1-fun-14b" in model_name.lower() or "wan2.2-fun" in model_name.lower() \
        or "wan2.2-i2v-a14b" in model_name.lower() or "wan2.2-t2v-a14b" in model_name.lower() or "wan2.2-ti2v-5b" in model_name.lower() \
        or "wan2.2-s2v" in model_name.lower() or "wan2.1-vace-14b" in model_name.lower() or "wan2.2-vace-fun" in model_name.lower():
        return [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
    else:
        print(f"The model {model_name} is not supported by TeaCache.")
        return None


class TeaCache:
    """
    Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages
    the fluctuating differences among model outputs across timesteps, thereby accelerating inference.
    Please refer to:
    1. https://github.com/ali-vilab/TeaCache.
    2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
    """
    def __init__(
        self,
        coefficients: list[float],
        num_steps: int,
        rel_l1_thresh: float = 0.0,
        num_skip_start_steps: int = 0,
        offload: bool = True,
    ):
        if num_steps < 1:
            raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
        if rel_l1_thresh < 0:
            raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
        if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
            raise ValueError(
                "`num_skip_start_steps` must be greater than or equal to 0 and "
                f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
            )
        self.coefficients = coefficients
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.num_skip_start_steps = num_skip_start_steps
        self.offload = offload
        self.rescale_func = np.poly1d(self.coefficients)

        self.cnt = 0
        self.should_calc = True
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        # Some pipelines concatenate the unconditional and text guide in forward.
        self.previous_residual = None
        # Some pipelines perform forward propagation separately on the unconditional and text guide.
        self.previous_residual_cond = None
        self.previous_residual_uncond = None

    @staticmethod
    def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor:
        rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean()

        return rel_l1_distance.cpu().item()

    def reset(self):
        self.cnt = 0
        self.should_calc = True
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.previous_residual_cond = None
        self.previous_residual_uncond = None

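`cache_utils.py` only stores TeaCache state; the skip decision itself lives in each transformer's forward. The intended loop, per the TeaCache paper, is to rescale the relative L1 distance between consecutive modulated inputs through the fitted polynomial, accumulate it, and recompute the block stack only once the accumulator crosses `rel_l1_thresh`. A hedged sketch of that decision loop (the exact logic here is an assumption based on the paper, not code from this file):

# Hypothetical decision-loop sketch based on the TeaCache paper; this module
# only stores the state, so the control flow below is an assumption.
from videox_fun.models.cache_utils import TeaCache, get_teacache_coefficients

coeffs = get_teacache_coefficients("Wan2.1-T2V-1.3B")
cache = TeaCache(coeffs, num_steps=50, rel_l1_thresh=0.08, num_skip_start_steps=5)

def should_compute(modulated_input):
    # Always compute in the warm-up window or on the very first step.
    if cache.cnt < cache.num_skip_start_steps or cache.previous_modulated_input is None:
        calc = True
    else:
        cache.accumulated_rel_l1_distance += cache.rescale_func(
            cache.compute_rel_l1_distance(cache.previous_modulated_input, modulated_input))
        calc = cache.accumulated_rel_l1_distance >= cache.rel_l1_thresh
        if calc:
            cache.accumulated_rel_l1_distance = 0  # reset once we recompute
    cache.previous_modulated_input = modulated_input
    cache.cnt += 1
    return calc  # False -> reuse cache.previous_residual instead of the blocks
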
videox_fun/models/cogvideox_transformer3d.py
ADDED
@@ -0,0 +1,840 @@
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import glob
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.models.attention import Attention, FeedForward
|
| 25 |
+
from diffusers.models.attention_processor import (
|
| 26 |
+
AttentionProcessor, CogVideoXAttnProcessor2_0,
|
| 27 |
+
FusedCogVideoXAttnProcessor2_0)
|
| 28 |
+
from diffusers.models.embeddings import (CogVideoXPatchEmbed,
|
| 29 |
+
TimestepEmbedding, Timesteps,
|
| 30 |
+
get_3d_sincos_pos_embed)
|
| 31 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 32 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 33 |
+
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
| 34 |
+
from diffusers.utils import is_torch_version, logging
|
| 35 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 36 |
+
from torch import nn
|
| 37 |
+
|
| 38 |
+
from ..dist import (get_sequence_parallel_rank,
|
| 39 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 40 |
+
xFuserLongContextAttention)
|
| 41 |
+
from ..dist.cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class CogVideoXPatchEmbed(nn.Module):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
patch_size: int = 2,
|
| 50 |
+
patch_size_t: Optional[int] = None,
|
| 51 |
+
in_channels: int = 16,
|
| 52 |
+
embed_dim: int = 1920,
|
| 53 |
+
text_embed_dim: int = 4096,
|
| 54 |
+
bias: bool = True,
|
| 55 |
+
sample_width: int = 90,
|
| 56 |
+
sample_height: int = 60,
|
| 57 |
+
sample_frames: int = 49,
|
| 58 |
+
temporal_compression_ratio: int = 4,
|
| 59 |
+
max_text_seq_length: int = 226,
|
| 60 |
+
spatial_interpolation_scale: float = 1.875,
|
| 61 |
+
temporal_interpolation_scale: float = 1.0,
|
| 62 |
+
use_positional_embeddings: bool = True,
|
| 63 |
+
use_learned_positional_embeddings: bool = True,
|
| 64 |
+
) -> None:
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
post_patch_height = sample_height // patch_size
|
| 68 |
+
post_patch_width = sample_width // patch_size
|
| 69 |
+
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
|
| 70 |
+
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
| 71 |
+
self.post_patch_height = post_patch_height
|
| 72 |
+
self.post_patch_width = post_patch_width
|
| 73 |
+
self.post_time_compression_frames = post_time_compression_frames
|
| 74 |
+
self.patch_size = patch_size
|
| 75 |
+
self.patch_size_t = patch_size_t
|
| 76 |
+
self.embed_dim = embed_dim
|
| 77 |
+
self.sample_height = sample_height
|
| 78 |
+
self.sample_width = sample_width
|
| 79 |
+
self.sample_frames = sample_frames
|
| 80 |
+
self.temporal_compression_ratio = temporal_compression_ratio
|
| 81 |
+
self.max_text_seq_length = max_text_seq_length
|
| 82 |
+
self.spatial_interpolation_scale = spatial_interpolation_scale
|
| 83 |
+
self.temporal_interpolation_scale = temporal_interpolation_scale
|
| 84 |
+
self.use_positional_embeddings = use_positional_embeddings
|
| 85 |
+
self.use_learned_positional_embeddings = use_learned_positional_embeddings
|
| 86 |
+
|
| 87 |
+
if patch_size_t is None:
|
| 88 |
+
# CogVideoX 1.0 checkpoints
|
| 89 |
+
self.proj = nn.Conv2d(
|
| 90 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
# CogVideoX 1.5 checkpoints
|
| 94 |
+
self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
|
| 95 |
+
|
| 96 |
+
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
| 97 |
+
|
| 98 |
+
if use_positional_embeddings or use_learned_positional_embeddings:
|
| 99 |
+
persistent = use_learned_positional_embeddings
|
| 100 |
+
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
|
| 101 |
+
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
|
| 102 |
+
|
| 103 |
+
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
|
| 104 |
+
post_patch_height = sample_height // self.patch_size
|
| 105 |
+
post_patch_width = sample_width // self.patch_size
|
| 106 |
+
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
|
| 107 |
+
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
| 108 |
+
|
| 109 |
+
pos_embedding = get_3d_sincos_pos_embed(
|
| 110 |
+
self.embed_dim,
|
| 111 |
+
(post_patch_width, post_patch_height),
|
| 112 |
+
post_time_compression_frames,
|
| 113 |
+
self.spatial_interpolation_scale,
|
| 114 |
+
self.temporal_interpolation_scale,
|
| 115 |
+
)
|
| 116 |
+
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
|
| 117 |
+
joint_pos_embedding = torch.zeros(
|
| 118 |
+
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
|
| 119 |
+
)
|
| 120 |
+
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
|
| 121 |
+
|
| 122 |
+
return joint_pos_embedding
|
| 123 |
+
|
| 124 |
+
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
|
| 125 |
+
r"""
|
| 126 |
+
Args:
|
| 127 |
+
text_embeds (`torch.Tensor`):
|
| 128 |
+
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
|
| 129 |
+
image_embeds (`torch.Tensor`):
|
| 130 |
+
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
|
| 131 |
+
"""
|
| 132 |
+
text_embeds = self.text_proj(text_embeds)
|
| 133 |
+
|
| 134 |
+
text_batch_size, text_seq_length, text_channels = text_embeds.shape
|
| 135 |
+
batch_size, num_frames, channels, height, width = image_embeds.shape
|
| 136 |
+
|
| 137 |
+
if self.patch_size_t is None:
|
| 138 |
+
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
| 139 |
+
image_embeds = self.proj(image_embeds)
|
| 140 |
+
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
|
| 141 |
+
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
| 142 |
+
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
| 143 |
+
else:
|
| 144 |
+
p = self.patch_size
|
| 145 |
+
p_t = self.patch_size_t
|
| 146 |
+
|
| 147 |
+
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
|
| 148 |
+
# b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c
|
| 149 |
+
image_embeds = image_embeds.reshape(
|
| 150 |
+
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
|
| 151 |
+
)
|
| 152 |
+
# b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2
|
| 153 |
+
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
|
| 154 |
+
image_embeds = self.proj(image_embeds)
|
| 155 |
+
|
| 156 |
+
embeds = torch.cat(
|
| 157 |
+
[text_embeds, image_embeds], dim=1
|
| 158 |
+
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
|
| 159 |
+
|
| 160 |
+
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
|
| 161 |
+
seq_length = height * width * num_frames // (self.patch_size**2)
|
| 162 |
+
# pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
| 163 |
+
pos_embeds = self.pos_embedding
|
| 164 |
+
emb_size = embeds.size()[-1]
|
| 165 |
+
pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
|
| 166 |
+
pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
|
| 167 |
+
pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
|
| 168 |
+
pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
|
| 169 |
+
pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
|
| 170 |
+
pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
|
| 171 |
+
embeds = embeds + pos_embeds
|
| 172 |
+
|
| 173 |
+
return embeds
|
| 174 |
+
|
| 175 |
+
@maybe_allow_in_graph
|
| 176 |
+
class CogVideoXBlock(nn.Module):
|
| 177 |
+
r"""
|
| 178 |
+
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
| 179 |
+
|
| 180 |
+
Parameters:
|
| 181 |
+
dim (`int`):
|
| 182 |
+
The number of channels in the input and output.
|
| 183 |
+
num_attention_heads (`int`):
|
| 184 |
+
The number of heads to use for multi-head attention.
|
| 185 |
+
attention_head_dim (`int`):
|
| 186 |
+
The number of channels in each head.
|
| 187 |
+
time_embed_dim (`int`):
|
| 188 |
+
The number of channels in timestep embedding.
|
| 189 |
+
dropout (`float`, defaults to `0.0`):
|
| 190 |
+
The dropout probability to use.
|
| 191 |
+
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
| 192 |
+
Activation function to be used in feed-forward.
|
| 193 |
+
attention_bias (`bool`, defaults to `False`):
|
| 194 |
+
Whether or not to use bias in attention projection layers.
|
| 195 |
+
qk_norm (`bool`, defaults to `True`):
|
| 196 |
+
Whether or not to use normalization after query and key projections in Attention.
|
| 197 |
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
| 198 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
| 199 |
+
norm_eps (`float`, defaults to `1e-5`):
|
| 200 |
+
Epsilon value for normalization layers.
|
| 201 |
+
final_dropout (`bool` defaults to `False`):
|
| 202 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
| 203 |
+
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
| 204 |
+
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
| 205 |
+
ff_bias (`bool`, defaults to `True`):
|
| 206 |
+
Whether or not to use bias in Feed-forward layer.
|
| 207 |
+
attention_out_bias (`bool`, defaults to `True`):
|
| 208 |
+
Whether or not to use bias in Attention output projection layer.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
def __init__(
|
| 212 |
+
self,
|
| 213 |
+
dim: int,
|
| 214 |
+
num_attention_heads: int,
|
| 215 |
+
attention_head_dim: int,
|
| 216 |
+
time_embed_dim: int,
|
| 217 |
+
dropout: float = 0.0,
|
| 218 |
+
activation_fn: str = "gelu-approximate",
|
| 219 |
+
attention_bias: bool = False,
|
| 220 |
+
qk_norm: bool = True,
|
| 221 |
+
norm_elementwise_affine: bool = True,
|
| 222 |
+
norm_eps: float = 1e-5,
|
| 223 |
+
final_dropout: bool = True,
|
| 224 |
+
ff_inner_dim: Optional[int] = None,
|
| 225 |
+
ff_bias: bool = True,
|
| 226 |
+
attention_out_bias: bool = True,
|
| 227 |
+
):
|
| 228 |
+
super().__init__()
|
| 229 |
+
|
| 230 |
+
# 1. Self Attention
|
| 231 |
+
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
| 232 |
+
|
| 233 |
+
self.attn1 = Attention(
|
| 234 |
+
query_dim=dim,
|
| 235 |
+
dim_head=attention_head_dim,
|
| 236 |
+
heads=num_attention_heads,
|
| 237 |
+
qk_norm="layer_norm" if qk_norm else None,
|
| 238 |
+
eps=1e-6,
|
| 239 |
+
bias=attention_bias,
|
| 240 |
+
out_bias=attention_out_bias,
|
| 241 |
+
processor=CogVideoXAttnProcessor2_0(),
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# 2. Feed Forward
|
| 245 |
+
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
| 246 |
+
|
| 247 |
+
self.ff = FeedForward(
|
| 248 |
+
dim,
|
| 249 |
+
dropout=dropout,
|
| 250 |
+
activation_fn=activation_fn,
|
| 251 |
+
final_dropout=final_dropout,
|
| 252 |
+
inner_dim=ff_inner_dim,
|
| 253 |
+
bias=ff_bias,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
def forward(
|
| 257 |
+
self,
|
| 258 |
+
hidden_states: torch.Tensor,
|
| 259 |
+
encoder_hidden_states: torch.Tensor,
|
| 260 |
+
temb: torch.Tensor,
|
| 261 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 262 |
+
) -> torch.Tensor:
|
| 263 |
+
text_seq_length = encoder_hidden_states.size(1)
|
| 264 |
+
|
| 265 |
+
# norm & modulate
|
| 266 |
+
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
| 267 |
+
hidden_states, encoder_hidden_states, temb
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# attention
|
| 271 |
+
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
| 272 |
+
hidden_states=norm_hidden_states,
|
| 273 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 274 |
+
image_rotary_emb=image_rotary_emb,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
| 278 |
+
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
| 279 |
+
|
| 280 |
+
# norm & modulate
|
| 281 |
+
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
| 282 |
+
hidden_states, encoder_hidden_states, temb
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# feed-forward
|
| 286 |
+
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
| 287 |
+
ff_output = self.ff(norm_hidden_states)
|
| 288 |
+
|
| 289 |
+
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
| 290 |
+
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
| 291 |
+
|
| 292 |
+
return hidden_states, encoder_hidden_states
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
| 296 |
+
"""
|
| 297 |
+
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
| 298 |
+
|
| 299 |
+
Parameters:
|
| 300 |
+
num_attention_heads (`int`, defaults to `30`):
|
| 301 |
+
The number of heads to use for multi-head attention.
|
| 302 |
+
attention_head_dim (`int`, defaults to `64`):
|
| 303 |
+
The number of channels in each head.
|
| 304 |
+
in_channels (`int`, defaults to `16`):
|
| 305 |
+
The number of channels in the input.
|
| 306 |
+
out_channels (`int`, *optional*, defaults to `16`):
|
| 307 |
+
The number of channels in the output.
|
| 308 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
| 309 |
+
Whether to flip the sin to cos in the time embedding.
|
| 310 |
+
time_embed_dim (`int`, defaults to `512`):
|
| 311 |
+
Output dimension of timestep embeddings.
|
| 312 |
+
text_embed_dim (`int`, defaults to `4096`):
|
| 313 |
+
Input dimension of text embeddings from the text encoder.
|
| 314 |
+
num_layers (`int`, defaults to `30`):
|
| 315 |
+
The number of layers of Transformer blocks to use.
|
| 316 |
+
dropout (`float`, defaults to `0.0`):
|
| 317 |
+
The dropout probability to use.
|
| 318 |
+
attention_bias (`bool`, defaults to `True`):
|
| 319 |
+
Whether or not to use bias in the attention projection layers.
|
| 320 |
+
sample_width (`int`, defaults to `90`):
|
| 321 |
+
The width of the input latents.
|
| 322 |
+
sample_height (`int`, defaults to `60`):
|
| 323 |
+
The height of the input latents.
|
| 324 |
+
sample_frames (`int`, defaults to `49`):
|
| 325 |
+
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
| 326 |
+
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
| 327 |
+
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
| 328 |
+
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
| 329 |
+
patch_size (`int`, defaults to `2`):
|
| 330 |
+
The size of the patches to use in the patch embedding layer.
|
| 331 |
+
temporal_compression_ratio (`int`, defaults to `4`):
|
| 332 |
+
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
| 333 |
+
max_text_seq_length (`int`, defaults to `226`):
|
| 334 |
+
The maximum sequence length of the input text embeddings.
|
| 335 |
+
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
| 336 |
+
Activation function to use in feed-forward.
|
| 337 |
+
timestep_activation_fn (`str`, defaults to `"silu"`):
|
| 338 |
+
Activation function to use when generating the timestep embeddings.
|
| 339 |
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
| 340 |
+
Whether or not to use elementwise affine in normalization layers.
|
| 341 |
+
norm_eps (`float`, defaults to `1e-5`):
|
| 342 |
+
The epsilon value to use in normalization layers.
|
| 343 |
+
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
| 344 |
+
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
| 345 |
+
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
| 346 |
+
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
_supports_gradient_checkpointing = True
|
| 350 |
+
|
| 351 |
+
@register_to_config
|
| 352 |
+
def __init__(
|
| 353 |
+
self,
|
| 354 |
+
num_attention_heads: int = 30,
|
| 355 |
+
attention_head_dim: int = 64,
|
| 356 |
+
in_channels: int = 16,
|
| 357 |
+
out_channels: Optional[int] = 16,
|
| 358 |
+
flip_sin_to_cos: bool = True,
|
| 359 |
+
freq_shift: int = 0,
|
| 360 |
+
time_embed_dim: int = 512,
|
| 361 |
+
text_embed_dim: int = 4096,
|
| 362 |
+
num_layers: int = 30,
|
| 363 |
+
dropout: float = 0.0,
|
| 364 |
+
attention_bias: bool = True,
|
| 365 |
+
sample_width: int = 90,
|
| 366 |
+
sample_height: int = 60,
|
| 367 |
+
sample_frames: int = 49,
|
| 368 |
+
patch_size: int = 2,
|
| 369 |
+
patch_size_t: Optional[int] = None,
|
| 370 |
+
temporal_compression_ratio: int = 4,
|
| 371 |
+
max_text_seq_length: int = 226,
|
| 372 |
+
activation_fn: str = "gelu-approximate",
|
| 373 |
+
timestep_activation_fn: str = "silu",
|
| 374 |
+
norm_elementwise_affine: bool = True,
|
| 375 |
+
norm_eps: float = 1e-5,
|
| 376 |
+
spatial_interpolation_scale: float = 1.875,
|
| 377 |
+
temporal_interpolation_scale: float = 1.0,
|
| 378 |
+
use_rotary_positional_embeddings: bool = False,
|
| 379 |
+
use_learned_positional_embeddings: bool = False,
|
| 380 |
+
patch_bias: bool = True,
|
| 381 |
+
add_noise_in_inpaint_model: bool = False,
|
| 382 |
+
):
|
| 383 |
+
super().__init__()
|
| 384 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 385 |
+
self.patch_size_t = patch_size_t
|
| 386 |
+
if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
|
| 387 |
+
raise ValueError(
|
| 388 |
+
"There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
|
| 389 |
+
"embeddings. If you're using a custom model and/or believe this should be supported, please open an "
|
| 390 |
+
"issue at https://github.com/huggingface/diffusers/issues."
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
# 1. Patch embedding
|
| 394 |
+
self.patch_embed = CogVideoXPatchEmbed(
|
| 395 |
+
patch_size=patch_size,
|
| 396 |
+
patch_size_t=patch_size_t,
|
| 397 |
+
in_channels=in_channels,
|
| 398 |
+
embed_dim=inner_dim,
|
| 399 |
+
text_embed_dim=text_embed_dim,
|
| 400 |
+
bias=patch_bias,
|
| 401 |
+
sample_width=sample_width,
|
| 402 |
+
sample_height=sample_height,
|
| 403 |
+
sample_frames=sample_frames,
|
| 404 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
| 405 |
+
max_text_seq_length=max_text_seq_length,
|
| 406 |
+
spatial_interpolation_scale=spatial_interpolation_scale,
|
| 407 |
+
temporal_interpolation_scale=temporal_interpolation_scale,
|
| 408 |
+
use_positional_embeddings=not use_rotary_positional_embeddings,
|
| 409 |
+
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
| 410 |
+
)
|
| 411 |
+
self.embedding_dropout = nn.Dropout(dropout)
|
| 412 |
+
|
| 413 |
+
# 2. Time embeddings
|
| 414 |
+
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
| 415 |
+
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
| 416 |
+
|
| 417 |
+
# 3. Define spatio-temporal transformers blocks
|
| 418 |
+
self.transformer_blocks = nn.ModuleList(
|
| 419 |
+
[
|
| 420 |
+
CogVideoXBlock(
|
| 421 |
+
dim=inner_dim,
|
| 422 |
+
num_attention_heads=num_attention_heads,
|
| 423 |
+
attention_head_dim=attention_head_dim,
|
| 424 |
+
time_embed_dim=time_embed_dim,
|
| 425 |
+
dropout=dropout,
|
| 426 |
+
activation_fn=activation_fn,
|
| 427 |
+
attention_bias=attention_bias,
|
| 428 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
| 429 |
+
norm_eps=norm_eps,
|
| 430 |
+
)
|
| 431 |
+
for _ in range(num_layers)
|
| 432 |
+
]
|
| 433 |
+
)
|
| 434 |
+
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
| 435 |
+
|
| 436 |
+
# 4. Output blocks
|
| 437 |
+
self.norm_out = AdaLayerNorm(
|
| 438 |
+
embedding_dim=time_embed_dim,
|
| 439 |
+
output_dim=2 * inner_dim,
|
| 440 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
| 441 |
+
norm_eps=norm_eps,
|
| 442 |
+
chunk_dim=1,
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
if patch_size_t is None:
|
| 446 |
+
# For CogVideox 1.0
|
| 447 |
+
output_dim = patch_size * patch_size * out_channels
|
| 448 |
+
else:
|
| 449 |
+
# For CogVideoX 1.5
|
| 450 |
+
output_dim = patch_size * patch_size * patch_size_t * out_channels
|
| 451 |
+
|
| 452 |
+
self.proj_out = nn.Linear(inner_dim, output_dim)
|
| 453 |
+
|
| 454 |
+
self.gradient_checkpointing = False
|
| 455 |
+
self.sp_world_size = 1
|
| 456 |
+
self.sp_world_rank = 0
|
| 457 |
+
|
| 458 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 459 |
+
self.gradient_checkpointing = value
|
| 460 |
+
|
| 461 |
+
def enable_multi_gpus_inference(self,):
|
| 462 |
+
self.sp_world_size = get_sequence_parallel_world_size()
|
| 463 |
+
self.sp_world_rank = get_sequence_parallel_rank()
|
| 464 |
+
self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
|
| 465 |
+
|
| 466 |
+
@property
|
| 467 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 468 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 469 |
+
r"""
|
| 470 |
+
Returns:
|
| 471 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 472 |
+
indexed by its weight name.
|
| 473 |
+
"""
|
| 474 |
+
# set recursively
|
| 475 |
+
processors = {}
|
| 476 |
+
|
| 477 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 478 |
+
if hasattr(module, "get_processor"):
|
| 479 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 480 |
+
|
| 481 |
+
for sub_name, child in module.named_children():
|
| 482 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 483 |
+
|
| 484 |
+
return processors
|
| 485 |
+
|
| 486 |
+
for name, module in self.named_children():
|
| 487 |
+
fn_recursive_add_processors(name, module, processors)
|
| 488 |
+
|
| 489 |
+
return processors
|
| 490 |
+
|
| 491 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 492 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 493 |
+
r"""
|
| 494 |
+
Sets the attention processor to use to compute attention.
|
| 495 |
+
|
| 496 |
+
Parameters:
|
| 497 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 498 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 499 |
+
for **all** `Attention` layers.
|
| 500 |
+
|
| 501 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 502 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 503 |
+
|
| 504 |
+
"""
|
| 505 |
+
count = len(self.attn_processors.keys())
|
| 506 |
+
|
| 507 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 508 |
+
raise ValueError(
|
| 509 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 510 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 514 |
+
if hasattr(module, "set_processor"):
|
| 515 |
+
if not isinstance(processor, dict):
|
| 516 |
+
module.set_processor(processor)
|
| 517 |
+
else:
|
| 518 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 519 |
+
|
| 520 |
+
for sub_name, child in module.named_children():
|
| 521 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 522 |
+
|
| 523 |
+
for name, module in self.named_children():
|
| 524 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 525 |
+
|
| 526 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
|
| 527 |
+
def fuse_qkv_projections(self):
|
| 528 |
+
"""
|
| 529 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 530 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 531 |
+
|
| 532 |
+
<Tip warning={true}>
|
| 533 |
+
|
| 534 |
+
This API is 🧪 experimental.
|
| 535 |
+
|
| 536 |
+
</Tip>
|
| 537 |
+
"""
|
| 538 |
+
self.original_attn_processors = None
|
| 539 |
+
|
| 540 |
+
for _, attn_processor in self.attn_processors.items():
|
| 541 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 542 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 543 |
+
|
| 544 |
+
self.original_attn_processors = self.attn_processors
|
| 545 |
+
|
| 546 |
+
for module in self.modules():
|
| 547 |
+
if isinstance(module, Attention):
|
| 548 |
+
module.fuse_projections(fuse=True)
|
| 549 |
+
|
| 550 |
+
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
|
| 551 |
+
|
| 552 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 553 |
+
def unfuse_qkv_projections(self):
|
| 554 |
+
"""Disables the fused QKV projection if enabled.
|
| 555 |
+
|
| 556 |
+
<Tip warning={true}>
|
| 557 |
+
|
| 558 |
+
This API is 🧪 experimental.
|
| 559 |
+
|
| 560 |
+
</Tip>
|
| 561 |
+
|
| 562 |
+
"""
|
| 563 |
+
if self.original_attn_processors is not None:
|
| 564 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 565 |
+
|
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        inpaint_latents: Optional[torch.Tensor] = None,
        control_latents: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape
        if num_frames == 1 and self.patch_size_t is not None:
            # Single image input with temporal patching: append a zero frame so
            # patch_size_t divides the frame count; it is dropped after unpatchify.
            hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
            if inpaint_latents is not None:
                inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
            if control_latents is not None:
                control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
            local_num_frames = num_frames + 1
        else:
            local_num_frames = num_frames

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # `timesteps` does not contain any weights and always returns f32 tensors,
        # but the time embedding may be running in fp16, so we need to cast here.
        # There might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        if inpaint_latents is not None:
            hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
        if control_latents is not None:
            hidden_states = torch.concat([hidden_states, control_latents], 2)
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # Context parallel: each rank keeps only its own chunk of the vision tokens.
        if self.sp_world_size > 1:
            hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
            if image_rotary_emb is not None:
                image_rotary_emb = (
                    torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
                    torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
                )

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        if self.sp_world_size > 1:
            hidden_states = get_sp_group().all_gather(hidden_states, dim=1)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if num_frames == 1:
            output = output[:, :num_frames, :]

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

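# Illustrative shape check (not from the original file): the temporal-patch
# unpatchify above, replayed on a tiny tensor with hypothetical sizes
# p = p_t = 2, 8 output channels, 2 frames, and a 4x4 latent grid.
import torch

b, f, c, h, w, p, p_t = 1, 2, 8, 4, 4, 2, 2
num_tokens = ((f + p_t - 1) // p_t) * (h // p) * (w // p)
tokens = torch.randn(b, num_tokens, c * p_t * p * p)
out = tokens.reshape(b, (f + p_t - 1) // p_t, h // p, w // p, -1, p_t, p, p)
out = out.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
print(out.shape)  # torch.Size([1, 2, 8, 4, 4]) == (batch, frames, channels, height, width)
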
    @classmethod
    def from_pretrained(
        cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
        low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
    ):
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"Loading 3D transformer's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)

        from diffusers.utils import WEIGHTS_NAME
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        model_file_safetensors = model_file.replace(".bin", ".safetensors")

        if "dict_mapping" in transformer_additional_kwargs.keys():
            for key in transformer_additional_kwargs["dict_mapping"]:
                transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]

        if low_cpu_mem_usage:
            try:
                import re

                from diffusers import __version__ as diffusers_version
                if diffusers_version >= "0.33.0":
                    from diffusers.models.model_loading_utils import load_model_dict_into_meta
                else:
                    from diffusers.models.modeling_utils import load_model_dict_into_meta
                from diffusers.utils import is_accelerate_available
                if is_accelerate_available():
                    import accelerate

                # Instantiate the model with empty weights.
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **transformer_additional_kwargs)

                param_device = "cpu"
                if os.path.exists(model_file):
                    state_dict = torch.load(model_file, map_location="cpu")
                elif os.path.exists(model_file_safetensors):
                    from safetensors.torch import load_file, safe_open
                    state_dict = load_file(model_file_safetensors)
                else:
                    from safetensors.torch import load_file, safe_open
                    model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
                    state_dict = {}
                    for _model_file_safetensors in model_files_safetensors:
                        _state_dict = load_file(_model_file_safetensors)
                        for key in _state_dict:
                            state_dict[key] = _state_dict[key]
                model._convert_deprecated_attention_blocks(state_dict)

                if diffusers_version >= "0.33.0":
                    # Diffusers refactored `load_model_dict_into_meta` in version 0.33.0; see
                    # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
                    load_model_dict_into_meta(
                        model,
                        state_dict,
                        dtype=torch_dtype,
                        model_name_or_path=pretrained_model_path,
                    )
                else:
                    # Move the params from the meta device to CPU.
                    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                    if len(missing_keys) > 0:
                        raise ValueError(
                            f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
                            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                            " those weights or else make sure your checkpoint file is correct."
                        )

                    unexpected_keys = load_model_dict_into_meta(
                        model,
                        state_dict,
                        device=param_device,
                        dtype=torch_dtype,
                        model_name_or_path=pretrained_model_path,
                    )

                if cls._keys_to_ignore_on_load_unexpected is not None:
                    for pat in cls._keys_to_ignore_on_load_unexpected:
                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                if len(unexpected_keys) > 0:
                    print(
                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                    )

                return model
            except Exception as e:
                print(
                    f"The low_cpu_mem_usage mode does not work because of {e}. Falling back to low_cpu_mem_usage=False."
                )

        model = cls.from_config(config, **transformer_additional_kwargs)
        if os.path.exists(model_file):
            state_dict = torch.load(model_file, map_location="cpu")
        elif os.path.exists(model_file_safetensors):
            from safetensors.torch import load_file, safe_open
            state_dict = load_file(model_file_safetensors)
        else:
            from safetensors.torch import load_file, safe_open
            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
            state_dict = {}
            for _model_file_safetensors in model_files_safetensors:
                _state_dict = load_file(_model_file_safetensors)
                for key in _state_dict:
                    state_dict[key] = _state_dict[key]

        # Adapt a checkpoint whose patch embedding has a different shape (e.g. extra
        # input channels for inpaint/control latents): copy the overlapping slice
        # and zero-initialize the remainder.
        if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
            new_shape = model.state_dict()['patch_embed.proj.weight'].size()
            if len(new_shape) == 5:
                state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
                state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
            elif len(new_shape) == 2:
                if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
                    model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
                    model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
                else:
                    model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
            else:
                if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
                    model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
                    model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
                else:
                    model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']

        tmp_state_dict = {}
        for key in state_dict:
            if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
                tmp_state_dict[key] = state_dict[key]
            else:
                print(key, "size mismatch, skipping")

        state_dict = tmp_state_dict

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m)

        params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
        print(f"### All Parameters: {sum(params) / 1e6} M")

        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
        print(f"### attn1 Parameters: {sum(params) / 1e6} M")

        model = model.to(torch_dtype)
        return model

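# Hypothetical usage sketch (not from the original file; the checkpoint path is
# a placeholder): `from_pretrained` reads config.json plus the *.bin or
# *.safetensors weights under the directory, optionally via meta-device init.
import torch
from videox_fun.models import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "path/to/CogVideoX-Fun",           # placeholder checkpoint directory
    subfolder="transformer",
    low_cpu_mem_usage=True,            # falls back to the plain path on failure
    torch_dtype=torch.bfloat16,
)
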
videox_fun/models/cogvideox_vae.py
ADDED
@@ -0,0 +1,1675 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.utils import logging
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.activations import get_activation
from diffusers.models.downsampling import CogVideoXDownsample3D
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.upsampling import CogVideoXUpsample3D
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class CogVideoXSafeConv3d(nn.Conv3d):
    r"""
    A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in the CogVideoX model.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Estimated tensor size in GiB, assuming 2-byte (fp16/bf16) elements.
        memory_count = (
            (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
        )

        # Set to 2 GB, suitable for cuDNN.
        if memory_count > 2:
            kernel_size = self.kernel_size[0]
            part_num = int(memory_count / 2) + 1
            input_chunks = torch.chunk(input, part_num, dim=2)

            if kernel_size > 1:
                # Re-prepend the last (kernel_size - 1) frames of the previous chunk
                # so the convolution sees the same temporal context at chunk borders.
                input_chunks = [input_chunks[0]] + [
                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
                    for i in range(1, len(input_chunks))
                ]

            output_chunks = []
            for input_chunk in input_chunks:
                output_chunks.append(super().forward(input_chunk))
            output = torch.cat(output_chunks, dim=2)
            return output
        else:
            return super().forward(input)

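# Back-of-envelope sketch (not from the original file): `memory_count` above
# assumes 2-byte elements, so it is roughly the input size in GiB. A made-up
# (1, 128, 193, 240, 360) input is ~4 GiB and would be split into 2 time chunks.
shape = (1, 128, 193, 240, 360)
numel = 1
for s in shape:
    numel *= s
memory_count = numel * 2 / 1024**3      # ~3.98 GiB
part_num = int(memory_count / 2) + 1    # -> 2 chunks along dim=2 (time)
print(round(memory_count, 2), part_num)
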
class CogVideoXCausalConv3d(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in the CogVideoX model.

    Args:
        in_channels (`int`): Number of channels in the input tensor.
        out_channels (`int`): Number of output channels produced by the convolution.
        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
        stride (`int`, defaults to `1`): Stride of the convolution.
        dilation (`int`, defaults to `1`): Dilation rate of the convolution.
        pad_mode (`str`, defaults to `"constant"`): Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "constant",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        # TODO(aryan): configure calculation based on stride and dilation in the future.
        # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi.
        time_pad = time_kernel_size - 1
        height_pad = (height_kernel_size - 1) // 2
        width_pad = (width_kernel_size - 1) // 2

        self.pad_mode = pad_mode
        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        self.temporal_dim = 2
        self.time_kernel_size = time_kernel_size

        stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = CogVideoXSafeConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def fake_context_parallel_forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if self.pad_mode == "replicate":
            inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
        else:
            kernel_size = self.time_kernel_size
            if kernel_size > 1:
                # Left-pad the time axis with the cache (or copies of the first frame).
                cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
                inputs = torch.cat(cached_inputs + [inputs], dim=2)
        return inputs

    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
        inputs = self.fake_context_parallel_forward(inputs, conv_cache)

        if self.pad_mode == "replicate":
            conv_cache = None
        else:
            padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
            inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

        output = self.conv(inputs)
        return output, conv_cache

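# Streaming sketch (not from the original file, assuming the class above is in
# scope): the causal conv pads only to the left of the time axis, and
# `conv_cache` carries the trailing kernel_size - 1 frames so a long video can
# be processed chunk by chunk.
import torch

conv = CogVideoXCausalConv3d(in_channels=3, out_channels=8, kernel_size=3)
chunk_a = torch.randn(1, 3, 5, 32, 32)            # (B, C, T, H, W)
out_a, cache = conv(chunk_a)                      # cache: (1, 3, 2, 32, 32)
out_b, cache = conv(torch.randn(1, 3, 5, 32, 32), conv_cache=cache)
print(out_a.shape, out_b.shape)                   # both torch.Size([1, 8, 5, 32, 32])
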
class CogVideoXSpatialNorm3D(nn.Module):
    r"""
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is
    specific to 3D video-like data.

    CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in the CogVideoX model.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
        groups (`int`):
            Number of groups to separate the channels into for group normalization.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
        groups: int = 32,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
        self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
        self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)

    def forward(
        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
    ) -> torch.Tensor:
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        if f.shape[2] > 1 and f.shape[2] % 2 == 1:
            # Resize the first frame separately so the causal first frame keeps its own scale.
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
            z_first = F.interpolate(z_first, size=f_first_size)
            z_rest = F.interpolate(z_rest, size=f_rest_size)
            zq = torch.cat([z_first, z_rest], dim=2)
        else:
            zq = F.interpolate(zq, size=f.shape[-3:])

        conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
        conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))

        norm_f = self.norm_layer(f)
        new_f = norm_f * conv_y + conv_b
        return new_f, new_conv_cache

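# Modulation sketch (not from the original file, assuming the class above is in
# scope): the layer computes new_f = GroupNorm(f) * conv_y(zq) + conv_b(zq),
# with zq interpolated to f's spatio-temporal size.
import torch

norm = CogVideoXSpatialNorm3D(f_channels=64, zq_channels=16, groups=32)
f = torch.randn(1, 64, 5, 32, 32)   # decoder features
zq = torch.randn(1, 16, 5, 8, 8)    # latent used as the conditioning signal
out, cache = norm(f, zq)
print(out.shape)                    # torch.Size([1, 64, 5, 32, 32])
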
class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D upsample layer used in CogVideoX by Tsinghua University & ZhipuAI.  # TODO: wait for paper release.

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

        self.auto_split_process = True
        self.first_frame_flag = False

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if self.compress_time:
            if self.auto_split_process:
                if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
                    # Split off the first frame: it is upsampled only spatially.
                    x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]

                    x_first = F.interpolate(x_first, scale_factor=2.0)
                    x_rest = F.interpolate(x_rest, scale_factor=2.0)
                    x_first = x_first[:, :, None, :, :]
                    inputs = torch.cat([x_first, x_rest], dim=2)
                elif inputs.shape[2] > 1:
                    inputs = F.interpolate(inputs, scale_factor=2.0)
                else:
                    inputs = inputs.squeeze(2)
                    inputs = F.interpolate(inputs, scale_factor=2.0)
                    inputs = inputs[:, :, None, :, :]
            else:
                if self.first_frame_flag:
                    inputs = inputs.squeeze(2)
                    inputs = F.interpolate(inputs, scale_factor=2.0)
                    inputs = inputs[:, :, None, :, :]
                else:
                    inputs = F.interpolate(inputs, scale_factor=2.0)
        else:
            # Only interpolate in 2D.
            b, c, t, h, w = inputs.shape
            inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            inputs = F.interpolate(inputs, scale_factor=2.0)
            inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)

        b, c, t, h, w = inputs.shape
        inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        inputs = self.conv(inputs)
        inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)

        return inputs

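# Shape sketch (not from the original file, assuming the class above is in
# scope): with compress_time=True and an odd frame count, the first frame is
# upsampled only spatially while the rest are doubled in time and space,
# preserving the causal 4k + 1 frame layout.
import torch

up = CogVideoXUpsample3D(in_channels=8, out_channels=8, compress_time=True)
x = torch.randn(1, 8, 5, 16, 16)    # (B, C, T, H, W), T odd
print(up(x).shape)                  # torch.Size([1, 8, 9, 32, 32])
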
class CogVideoXResnetBlock3D(nn.Module):
    r"""
    A 3D ResNet block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        non_linearity (`str`, defaults to `"swish"`):
            Activation function to use.
        conv_shortcut (`bool`, defaults to `False`):
            Whether or not to use a convolution shortcut.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (`str`, defaults to `"first"`):
            Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        conv_shortcut: bool = False,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nonlinearity = get_activation(non_linearity)
        self.use_conv_shortcut = conv_shortcut
        self.spatial_norm_dim = spatial_norm_dim

        if spatial_norm_dim is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
        else:
            self.norm1 = CogVideoXSpatialNorm3D(
                f_channels=in_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )
            self.norm2 = CogVideoXSpatialNorm3D(
                f_channels=out_channels,
                zq_channels=spatial_norm_dim,
                groups=groups,
            )

        self.conv1 = CogVideoXCausalConv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        if temb_channels > 0:
            self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)

        self.dropout = nn.Dropout(dropout)
        self.conv2 = CogVideoXCausalConv3d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CogVideoXCausalConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
                )
            else:
                self.conv_shortcut = CogVideoXSafeConv3d(
                    in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(
        self,
        inputs: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states = inputs

        if zq is not None:
            hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))

        if temb is not None:
            hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if zq is not None:
            hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
                    inputs, conv_cache=conv_cache.get("conv_shortcut")
                )
            else:
                inputs = self.conv_shortcut(inputs)

        hidden_states = hidden_states + inputs
        return hidden_states, new_conv_cache

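# Channel-change sketch (not from the original file, assuming the class above
# is in scope): when in/out channels differ, a 1x1(x1) shortcut conv projects
# the residual path to match.
import torch

block = CogVideoXResnetBlock3D(in_channels=64, out_channels=128, temb_channels=0)
x = torch.randn(1, 64, 3, 32, 32)
y, conv_cache = block(x, temb=None, zq=None)
print(y.shape)                      # torch.Size([1, 128, 3, 32, 32])
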
class CogVideoXDownBlock3D(nn.Module):
    r"""
    A downsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        add_downsample (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input
            dimension.
        compress_time (`bool`, defaults to `False`):
            Whether or not to downsample across the temporal dimension.
        pad_mode (`str`, defaults to `"first"`):
            Padding mode.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 0,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.downsamplers = None

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    CogVideoXDownsample3D(
                        out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        r"""Forward method of the `CogVideoXDownBlock3D` class."""

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, new_conv_cache

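# Stacking sketch (not from the original file, assuming the class above is in
# scope): with add_downsample=False the block is just `num_layers` resnets, so
# only the channel count changes.
import torch

down = CogVideoXDownBlock3D(
    in_channels=64, out_channels=128, temb_channels=0, num_layers=2, add_downsample=False
)
x = torch.randn(1, 64, 3, 32, 32)
y, conv_cache = down(x, temb=None, zq=None)
print(y.shape)                      # torch.Size([1, 128, 3, 32, 32])
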
class CogVideoXMidBlock3D(nn.Module):
    r"""
    A middle block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, *optional*):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        pad_mode (`str`, defaults to `"first"`):
            Padding mode.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for _ in range(num_layers):
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    spatial_norm_dim=spatial_norm_dim,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        r"""Forward method of the `CogVideoXMidBlock3D` class."""

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        return hidden_states, new_conv_cache

class CogVideoXUpBlock3D(nn.Module):
    r"""
    An upsampling block used in the CogVideoX model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        temb_channels (`int`, defaults to `512`):
            Number of time embedding channels.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        resnet_groups (`int`, defaults to `32`):
            Number of groups to separate the channels into for group normalization.
        spatial_norm_dim (`int`, defaults to `16`):
            The dimension to use for spatial norm if it is to be used instead of group norm.
        add_upsample (`bool`, defaults to `True`):
            Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input
            dimension.
        compress_time (`bool`, defaults to `False`):
            Whether or not to upsample across the temporal dimension.
        pad_mode (`str`, defaults to `"first"`):
            Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: int = 16,
        add_upsample: bool = True,
        upsample_padding: int = 1,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        resnets = []
        for i in range(num_layers):
            in_channel = in_channels if i == 0 else out_channels
            resnets.append(
                CogVideoXResnetBlock3D(
                    in_channels=in_channel,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    spatial_norm_dim=spatial_norm_dim,
                    pad_mode=pad_mode,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.upsamplers = None

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    CogVideoXUpsample3D(
                        out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        r"""Forward method of the `CogVideoXUpBlock3D` class."""

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):
                        return module(*inputs)

                    return create_forward

                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    zq,
                    conv_cache.get(conv_cache_key),
                )
            else:
                hidden_states, new_conv_cache[conv_cache_key] = resnet(
                    hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states, new_conv_cache

class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXDownBlock3D"` blocks):
            The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is supported.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # log2 of temporal_compress_times
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        self.down_blocks = nn.ModuleList([])

        # down blocks
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            if down_block_type == "CogVideoXDownBlock3D":
                down_block = CogVideoXDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    add_downsample=not is_final_block,
                    compress_time=compress_time,
                )
            else:
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        r"""The forward method of the `CogVideoXEncoder3D` class."""

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block),
                    hidden_states,
                    temb,
                    None,
                    conv_cache.get(conv_cache_key),
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                None,
                conv_cache.get("mid_block"),
            )
        else:
            # 1. Down
            for i, down_block in enumerate(self.down_blocks):
                conv_cache_key = f"down_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = down_block(
                    hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
                )

            # 2. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
            )

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)

        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache

| 902 |
+


class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("CogVideoXUpBlock3D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])

        output_channel = reversed_block_out_channels[0]
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            if up_block_type == "CogVideoXUpBlock3D":
                up_block = CogVideoXUpBlock3D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    spatial_norm_dim=in_channels,
                    add_upsample=not is_final_block,
                    compress_time=compress_time,
                    pad_mode=pad_mode,
                )
                prev_output_channel = output_channel
            else:
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        conv_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        r"""The forward method of the `CogVideoXDecoder3D` class."""

        new_conv_cache = {}
        conv_cache = conv_cache or {}

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Mid
            hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block),
                hidden_states,
                temb,
                sample,
                conv_cache.get("mid_block"),
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block),
                    hidden_states,
                    temb,
                    sample,
                    conv_cache.get(conv_cache_key),
                )
        else:
            # 1. Mid
            hidden_states, new_conv_cache["mid_block"] = self.mid_block(
                hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
            )

            # 2. Up
            for i, up_block in enumerate(self.up_blocks):
                conv_cache_key = f"up_block_{i}"
                hidden_states, new_conv_cache[conv_cache_key] = up_block(
                    hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
                )

        # 3. Post-process
        hidden_states, new_conv_cache["norm_out"] = self.norm_out(
            hidden_states, sample, conv_cache=conv_cache.get("norm_out")
        )
        hidden_states = self.conv_act(hidden_states)
        hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))

        return hidden_states, new_conv_cache


class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
    [CogVideoX](https://github.com/THUDM/CogVideo).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CogVideoXDownBlock3D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("CogVideoXUpBlock3D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 256, 512)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        sample_height (`int`, *optional*, defaults to `480`): Sample input height.
        sample_width (`int`, *optional*, defaults to `720`): Sample input width.
        scaling_factor (`float`, *optional*, defaults to `1.15258426`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula:
            `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
            Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        force_upcast (`bool`, *optional*, defaults to `True`):
            If enabled, it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL.
            The VAE can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["CogVideoXResnetBlock3D"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        up_block_types: Tuple[str] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int] = (128, 256, 256, 512),
        latent_channels: int = 16,
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        temporal_compression_ratio: float = 4,
        sample_height: int = 480,
        sample_width: int = 720,
        scaling_factor: float = 1.15258426,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = False,
        use_post_quant_conv: bool = False,
        invert_scale_latents: bool = False,
    ):
        super().__init__()

        self.encoder = CogVideoXEncoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )
        self.decoder = CogVideoXDecoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )
        self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
        self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None

        self.use_slicing = False
        self.use_tiling = False
        self.auto_split_process = False

        # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
        # recommended because the temporal parts of the VAE, here, are tricky to understand.
        # If you decode X latent frames together, the number of output frames is:
        #     (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
        #
        # Example with num_latent_frames_batch_size = 2:
        #     - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
        #         => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
        #         => 6 * 8 = 48 frames
        #     - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
        #         => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
        #            ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
        #         => 1 * 9 + 5 * 8 = 49 frames
        # It has been implemented this way so as to not have "magic values" in the code base that would be hard to
        # explain. Note that setting it to anything other than 2 would give poor results because the VAE hasn't been
        # trained to be adaptive with different number of temporal frames.
        self.num_latent_frames_batch_size = 2
        self.num_sample_frames_batch_size = 8
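
        # For intuition (illustrative arithmetic, assuming the config defaults above): a 49-frame
        # 480x720 video encodes to latents of shape [B, 16, 13, 60, 90] -- the causal temporal
        # stack keeps the first frame and compresses the rest by temporal_compression_ratio=4
        # ((49 - 1) / 4 + 1 = 13), while height and width shrink by
        # 2 ** (len(block_out_channels) - 1) = 8.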

        # We make the minimum height and width of sample for tiling half that of the generally supported resolution.
        self.tile_sample_min_height = sample_height // 2
        self.tile_sample_min_width = sample_width // 2
        self.tile_latent_min_height = int(
            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
        )
        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))

        # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
        # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
        # and so the tiling implementation has only been tested on those specific resolutions.
        self.tile_overlap_factor_height = 1 / 6
        self.tile_overlap_factor_width = 1 / 5

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
            module.gradient_checkpointing = value

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_overlap_factor_height: Optional[float] = None,
        tile_overlap_factor_width: Optional[float] = None,
    ) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_overlap_factor_height (`float`, *optional*):
                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
                no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
                value might cause more tiles to be processed, leading to a slowdown of the decoding process.
            tile_overlap_factor_width (`float`, *optional*):
                The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
                are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
                value might cause more tiles to be processed, leading to a slowdown of the decoding process.
        """
        self.use_tiling = True
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_latent_min_height = int(
            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
        )
        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
        self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
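
        # Minimal usage sketch (values are illustrative; the defaults derived from
        # `sample_height`/`sample_width` are used when arguments are omitted):
        #
        # >>> vae.enable_tiling(
        # ...     tile_sample_min_height=240,
        # ...     tile_sample_min_width=360,
        # ...     tile_overlap_factor_height=1 / 6,
        # ...     tile_overlap_factor_width=1 / 5,
        # ... )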

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _set_first_frame(self):
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXUpsample3D):
                module.auto_split_process = False
                module.first_frame_flag = True

    def _set_rest_frame(self):
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXUpsample3D):
                module.auto_split_process = False
                module.first_frame_flag = False

    def enable_auto_split_process(self) -> None:
        self.auto_split_process = True
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXUpsample3D):
                module.auto_split_process = True

    def disable_auto_split_process(self) -> None:
        self.auto_split_process = False

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = x.shape

        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)

        frame_batch_size = self.num_sample_frames_batch_size
        # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or
        # `frame_batch_size * k + 1` for some k.
        # As the extra single frame is handled inside the loop, it is not required to round up here.
        num_batches = max(num_frames // frame_batch_size, 1)
        conv_cache = None
        enc = []

        for i in range(num_batches):
            remaining_frames = num_frames % frame_batch_size
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            x_intermediate = x[:, :, start_frame:end_frame]
            x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
            if self.quant_conv is not None:
                x_intermediate = self.quant_conv(x_intermediate)
            enc.append(x_intermediate)

        enc = torch.cat(enc, dim=2)
        return enc

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded videos. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        batch_size, num_channels, num_frames, height, width = z.shape

        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
            return self.tiled_decode(z, return_dict=return_dict)

        if self.auto_split_process:
            frame_batch_size = self.num_latent_frames_batch_size
            num_batches = max(num_frames // frame_batch_size, 1)
            conv_cache = None
            dec = []

            for i in range(num_batches):
                remaining_frames = num_frames % frame_batch_size
                start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
                end_frame = frame_batch_size * (i + 1) + remaining_frames
                z_intermediate = z[:, :, start_frame:end_frame]
                if self.post_quant_conv is not None:
                    z_intermediate = self.post_quant_conv(z_intermediate)
                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
                dec.append(z_intermediate)
        else:
            conv_cache = None
            start_frame = 0
            end_frame = 1
            dec = []

            # The first latent frame is decoded on its own, then the rest in fixed-size batches.
            self._set_first_frame()
            z_intermediate = z[:, :, start_frame:end_frame]
            if self.post_quant_conv is not None:
                z_intermediate = self.post_quant_conv(z_intermediate)
            z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
            dec.append(z_intermediate)

            self._set_rest_frame()
            start_frame = end_frame
            end_frame += self.num_latent_frames_batch_size

            while start_frame < num_frames:
                z_intermediate = z[:, :, start_frame:end_frame]
                if self.post_quant_conv is not None:
                    z_intermediate = self.post_quant_conv(z_intermediate)
                z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
                dec.append(z_intermediate)
                start_frame = end_frame
                end_frame += self.num_latent_frames_batch_size

        dec = torch.cat(dec, dim=2)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)
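
    # Minimal round-trip sketch (illustrative; assumes `vae` is a loaded AutoencoderKLCogVideoX
    # and `video` is a [B, 3, T, H, W] tensor):
    #
    # >>> posterior = vae.encode(video).latent_dist
    # >>> latents = posterior.sample() * vae.config.scaling_factor
    # >>> frames = vae.decode(latents / vae.config.scaling_factor).sample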

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b
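
    # For intuition: both helpers are linear crossfades over the overlap region. With
    # blend_extent = 4, the first overlap rows/columns of tile `b` get weights 0/4, 1/4, 2/4, 3/4
    # while the matching edge of the previously processed tile `a` gets weights 4/4, 3/4, 2/4, 1/4,
    # so the seam fades smoothly from `a` into `b`.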

    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        # For a rough memory estimate, take a look at the `tiled_decode` method.
        batch_size, num_channels, num_frames, height, width = x.shape

        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
        row_limit_height = self.tile_latent_min_height - blend_extent_height
        row_limit_width = self.tile_latent_min_width - blend_extent_width
        frame_batch_size = self.num_sample_frames_batch_size

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, overlap_height):
            row = []
            for j in range(0, width, overlap_width):
                # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or
                # `frame_batch_size * k + 1` for some k.
                # As the extra single frame is handled inside the loop, it is not required to round up here.
                num_batches = max(num_frames // frame_batch_size, 1)
                conv_cache = None
                time = []

                for k in range(num_batches):
                    remaining_frames = num_frames % frame_batch_size
                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                    end_frame = frame_batch_size * (k + 1) + remaining_frames
                    tile = x[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_sample_min_height,
                        j : j + self.tile_sample_min_width,
                    ]
                    tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
                    if self.quant_conv is not None:
                        tile = self.quant_conv(tile)
                    time.append(tile)

                row.append(torch.cat(time, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
            result_rows.append(torch.cat(result_row, dim=4))

        enc = torch.cat(result_rows, dim=3)
        return enc

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        # Rough memory assessment:
        #   - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
        #   - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
        #   - Assume fp16 (2 bytes per value).
        # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
        #
        # Memory assessment when using tiling:
        #   - Assume everything as above but now HxW is 240x360 by tiling in half
        # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB

        batch_size, num_channels, num_frames, height, width = z.shape

        overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
        overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
        blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
        row_limit_height = self.tile_sample_min_height - blend_extent_height
        row_limit_width = self.tile_sample_min_width - blend_extent_width
        frame_batch_size = self.num_latent_frames_batch_size

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, overlap_height):
            row = []
            for j in range(0, width, overlap_width):
                if self.auto_split_process:
                    num_batches = max(num_frames // frame_batch_size, 1)
                    conv_cache = None
                    time = []

                    for k in range(num_batches):
                        remaining_frames = num_frames % frame_batch_size
                        start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                        end_frame = frame_batch_size * (k + 1) + remaining_frames
                        tile = z[
                            :,
                            :,
                            start_frame:end_frame,
                            i : i + self.tile_latent_min_height,
                            j : j + self.tile_latent_min_width,
                        ]
                        if self.post_quant_conv is not None:
                            tile = self.post_quant_conv(tile)
                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                        time.append(tile)

                    row.append(torch.cat(time, dim=2))
                else:
                    conv_cache = None
                    start_frame = 0
                    end_frame = 1
                    dec = []

                    tile = z[
                        :,
                        :,
                        start_frame:end_frame,
                        i : i + self.tile_latent_min_height,
                        j : j + self.tile_latent_min_width,
                    ]

                    self._set_first_frame()
                    if self.post_quant_conv is not None:
                        tile = self.post_quant_conv(tile)
                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                    dec.append(tile)

                    self._set_rest_frame()
                    start_frame = end_frame
                    end_frame += self.num_latent_frames_batch_size

                    while start_frame < num_frames:
                        tile = z[
                            :,
                            :,
                            start_frame:end_frame,
                            i : i + self.tile_latent_min_height,
                            j : j + self.tile_latent_min_width,
                        ]
                        if self.post_quant_conv is not None:
                            tile = self.post_quant_conv(tile)
                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                        dec.append(tile)
                        start_frame = end_frame
                        end_frame += self.num_latent_frames_batch_size

                    row.append(torch.cat(dec, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
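
    # Worked tile-grid example (illustrative, using the defaults for 480x720 samples):
    # tile_latent_min_height = 30, so overlap_height = int(30 * (1 - 1/6)) = 25 and a 60-row latent
    # is decoded at row offsets 0, 25 and 50; blend_extent_height = int(240 * 1/6) = 40 output rows
    # are crossfaded between vertically adjacent tiles before cropping to row_limit_height = 200.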

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, Tuple[DecoderOutput]]:
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z)
        if not return_dict:
            return (dec,)
        return dec

    @classmethod
    def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)

        model = cls.from_config(config, **vae_additional_kwargs)
        from diffusers.utils import WEIGHTS_NAME
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        model_file_safetensors = model_file.replace(".bin", ".safetensors")
        if os.path.exists(model_file_safetensors):
            from safetensors.torch import load_file, safe_open
            state_dict = load_file(model_file_safetensors)
        else:
            if not os.path.isfile(model_file):
                raise RuntimeError(f"{model_file} does not exist")
            state_dict = torch.load(model_file, map_location="cpu")
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m, u)
        return model
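
# Minimal loading sketch (the path and subfolder are illustrative; any directory containing a
# config.json plus a .bin or .safetensors weight file works):
#
# >>> vae = AutoencoderKLCogVideoX.from_pretrained("path/to/CogVideoX-2b", subfolder="vae")
# >>> vae.enable_tiling()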
videox_fun/models/flux_transformer2d.py
ADDED
@@ -0,0 +1,940 @@
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import (AdaLayerNormContinuous,
                                            AdaLayerNormZero,
                                            AdaLayerNormZeroSingle)
from diffusers.utils import (USE_PEFT_BACKEND, logging, scale_lora_layers,
                             unscale_lora_layers)
from diffusers.utils.torch_utils import maybe_allow_in_graph

from ..dist import (FluxMultiGPUsAttnProcessor2_0, get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group)
from .attention_utils import attention

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    return _get_projections(attn, hidden_states, encoder_hidden_states)


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> torch.Tensor:
    """
    Apply rotary embeddings to an input tensor using the given frequency tensor. This function applies rotary
    embeddings to the given query or key tensor 'x' using the provided frequency tensor 'freqs_cis'. The input
    tensor is reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility.
    The resulting tensor contains rotary embeddings and is returned as a real tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
        freqs_cis (`Tuple[torch.Tensor]`):
            Precomputed frequency tensors for the complex exponentials, as a (cos, sin) pair of shape ([S, D], [S, D]).

    Returns:
        `torch.Tensor`: The query or key tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
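
# Worked example of the `use_real` path above: for one (even, odd) channel pair (x0, x1) with
# angle theta, the pairwise unbind/stack yields x_rotated = (-x1, x0), so
# out = (x0 * cos(theta) - x1 * sin(theta), x1 * cos(theta) + x0 * sin(theta)) -- a 2D rotation
# of each channel pair, which is exactly the standard rotary position embedding.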

class FluxAttnProcessor:
    _attention_backend = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your PyTorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        text_seq_len: int = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = attention(
            query, key, value, attn_mask=attention_mask,
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
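
# Shape walk-through for the processor above (sizes are illustrative): with 512 text tokens and
# 4096 image tokens, the text Q/K/V are concatenated in front of the image Q/K/V, joint attention
# runs over all 4608 tokens, and `split_with_sizes` afterwards returns the first 512 rows as the
# updated text stream and the remainder as the image stream.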
class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
| 182 |
+
"""Flux Attention processor for IP-Adapter."""
|
| 183 |
+
|
| 184 |
+
_attention_backend = None
|
| 185 |
+
|
| 186 |
+
def __init__(
|
| 187 |
+
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
| 188 |
+
):
|
| 189 |
+
super().__init__()
|
| 190 |
+
|
| 191 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 192 |
+
raise ImportError(
|
| 193 |
+
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
self.hidden_size = hidden_size
|
| 197 |
+
self.cross_attention_dim = cross_attention_dim
|
| 198 |
+
|
| 199 |
+
if not isinstance(num_tokens, (tuple, list)):
|
| 200 |
+
num_tokens = [num_tokens]
|
| 201 |
+
|
| 202 |
+
if not isinstance(scale, list):
|
| 203 |
+
scale = [scale] * len(num_tokens)
|
| 204 |
+
if len(scale) != len(num_tokens):
|
| 205 |
+
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
| 206 |
+
self.scale = scale
|
| 207 |
+
|
| 208 |
+
self.to_k_ip = nn.ModuleList(
|
| 209 |
+
[
|
| 210 |
+
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
| 211 |
+
for _ in range(len(num_tokens))
|
| 212 |
+
]
|
| 213 |
+
)
|
| 214 |
+
self.to_v_ip = nn.ModuleList(
|
| 215 |
+
[
|
| 216 |
+
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
| 217 |
+
for _ in range(len(num_tokens))
|
| 218 |
+
]
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
def __call__(
|
| 222 |
+
self,
|
| 223 |
+
attn: "FluxAttention",
|
| 224 |
+
hidden_states: torch.Tensor,
|
| 225 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 226 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 227 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 228 |
+
ip_hidden_states: Optional[List[torch.Tensor]] = None,
|
| 229 |
+
ip_adapter_masks: Optional[torch.Tensor] = None,
|
| 230 |
+
) -> torch.Tensor:
|
| 231 |
+
batch_size = hidden_states.shape[0]
|
| 232 |
+
|
| 233 |
+
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
| 234 |
+
attn, hidden_states, encoder_hidden_states
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
query = query.unflatten(-1, (attn.heads, -1))
|
| 238 |
+
key = key.unflatten(-1, (attn.heads, -1))
|
| 239 |
+
value = value.unflatten(-1, (attn.heads, -1))
|
| 240 |
+
|
| 241 |
+
query = attn.norm_q(query)
|
| 242 |
+
key = attn.norm_k(key)
|
| 243 |
+
ip_query = query
|
| 244 |
+
|
| 245 |
+
if encoder_hidden_states is not None:
|
| 246 |
+
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
| 247 |
+
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
| 248 |
+
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
| 249 |
+
|
| 250 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
| 251 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
| 252 |
+
|
| 253 |
+
query = torch.cat([encoder_query, query], dim=1)
|
| 254 |
+
key = torch.cat([encoder_key, key], dim=1)
|
| 255 |
+
value = torch.cat([encoder_value, value], dim=1)
|
| 256 |
+
|
| 257 |
+
if image_rotary_emb is not None:
|
| 258 |
+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 259 |
+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 260 |
+
|
| 261 |
+
hidden_states = attention(
|
| 262 |
+
query,
|
| 263 |
+
key,
|
| 264 |
+
value,
|
| 265 |
+
attn_mask=attention_mask,
|
| 266 |
+
dropout_p=0.0,
|
| 267 |
+
is_causal=False,
|
| 268 |
+
)
|
| 269 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 270 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 271 |
+
|
| 272 |
+
if encoder_hidden_states is not None:
|
| 273 |
+
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
| 274 |
+
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
| 275 |
+
)
|
| 276 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 277 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 278 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 279 |
+
|
| 280 |
+
# IP-adapter
|
| 281 |
+
ip_attn_output = torch.zeros_like(hidden_states)
|
| 282 |
+
|
| 283 |
+
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
| 284 |
+
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
| 285 |
+
):
|
| 286 |
+
ip_key = to_k_ip(current_ip_hidden_states)
|
| 287 |
+
ip_value = to_v_ip(current_ip_hidden_states)
|
| 288 |
+
|
| 289 |
+
ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
|
| 290 |
+
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
|
| 291 |
+
|
| 292 |
+
current_ip_hidden_states = dispatch_attention_fn(
|
| 293 |
+
ip_query,
|
| 294 |
+
ip_key,
|
| 295 |
+
ip_value,
|
| 296 |
+
attn_mask=None,
|
| 297 |
+
dropout_p=0.0,
|
| 298 |
+
is_causal=False,
|
| 299 |
+
backend=self._attention_backend,
|
| 300 |
+
)
|
| 301 |
+
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
|
| 302 |
+
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
|
| 303 |
+
ip_attn_output += scale * current_ip_hidden_states
|
| 304 |
+
|
| 305 |
+
return hidden_states, encoder_hidden_states, ip_attn_output
|
| 306 |
+
else:
|
| 307 |
+
return hidden_states
|
| 308 |
+
|
| 309 |
+
|
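# Construction sketch (added for illustration, not part of the upstream file; the
# sizes below are assumed values): one (to_k_ip, to_v_ip) projection pair is
# created per adapter listed in `num_tokens`, and each adapter's attention
# output is weighted by its entry in `scale` before being summed into
# `ip_attn_output` above.
#
#   proc = FluxIPAdapterAttnProcessor(hidden_size=3072, cross_attention_dim=2048,
#                                     num_tokens=(4,), scale=0.6)
#   # scale is broadcast to [0.6]; len(proc.to_k_ip) == len(proc.to_v_ip) == 1

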
class FluxAttention(torch.nn.Module):
    _default_processor_cls = FluxAttnProcessor
    _available_processors = [
        FluxAttnProcessor,
        FluxIPAdapterAttnProcessor,
    ]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.pre_only:
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
            self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            self.processor = self._default_processor_cls()
        else:
            self.processor = processor

    def set_processor(self, processor: "AttnProcessor") -> None:
        r"""
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        r"""
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)


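# Minimal construction sketch (an illustrative addition, not upstream code): with
# out_dim left as None, inner_dim works out to heads * dim_head and the default
# FluxAttnProcessor is attached automatically.
#
#   attn = FluxAttention(query_dim=128, heads=2, dim_head=64, bias=True)
#   assert attn.inner_dim == 128 and isinstance(attn.processor, FluxAttnProcessor)

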
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        self.attn = FluxAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=FluxAttnProcessor(),
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        text_seq_len = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            text_seq_len=text_seq_len,
            **joint_attention_kwargs,
        )
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
        return encoder_hidden_states, hidden_states


@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        self.attn = FluxAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=FluxAttnProcessor(),
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        joint_attention_kwargs = joint_attention_kwargs or {}

        # Attention.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        if len(attention_outputs) == 2:
            attn_output, context_attn_output = attention_outputs
        elif len(attention_outputs) == 3:
            attn_output, context_attn_output, ip_attn_output = attention_outputs

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output
        if len(attention_outputs) == 3:
            hidden_states = hidden_states + ip_attn_output

        # Process attention outputs for the `encoder_hidden_states`.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states


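# Note (added for clarity): each sub-layer above follows the AdaLayerNormZero
# pattern. For a modulated input
#     x_mod = norm(x) * (1 + scale[:, None]) + shift[:, None]
# the residual update is
#     x = x + gate.unsqueeze(1) * sublayer(x_mod)
# with separate (shift, scale, gate) triples for the attention and MLP paths on
# each of the image and text streams.

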
class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin


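# Usage sketch (illustrative, not upstream code): `ids` carries one row per token
# and one column per rotary axis; each axis i contributes axes_dim[i] channels,
# so the outputs have sum(axes_dim) channels.
#
#   pe = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
#   ids = torch.zeros(4096, 3)          # [seq_len, n_axes]
#   freqs_cos, freqs_sin = pe(ids)      # each [4096, 128] since 16 + 56 + 56 = 128

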
class FluxTransformer2DModel(
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        patch_size (`int`, defaults to `1`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `19`):
            The number of layers of dual stream DiT blocks to use.
        num_single_layers (`int`, defaults to `38`):
            The number of layers of single stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `4096`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        pooled_projection_dim (`int`, defaults to `768`):
            The number of dimensions to use for the pooled projection.
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for the guidance-distilled variant of the model.
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
    # _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    # _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

        self.sp_world_size = 1
        self.sp_world_rank = 0

    def enable_multi_gpus_inference(self):
        self.sp_world_size = get_sequence_parallel_world_size()
        self.sp_world_rank = get_sequence_parallel_rank()
        self.all_gather = get_sp_group().all_gather
        self.set_attn_processor(FluxMultiGPUsAttnProcessor2_0())

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate the denoising step.
            controlnet_block_samples (`list` of `torch.Tensor`):
                A list of tensors that, if specified, are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` as a 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch.Tensor."
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` as a 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch.Tensor."
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        # Context Parallel
        if self.sp_world_size > 1:
            hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
            if image_rotary_emb is not None:
                txt_rotary_emb = (
                    image_rotary_emb[0][:encoder_hidden_states.shape[1]],
                    image_rotary_emb[1][:encoder_hidden_states.shape[1]]
                )
                image_rotary_emb = (
                    torch.chunk(image_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
                    torch.chunk(image_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
                )
                image_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \
                    for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, image_rotary_emb)]

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                # For Xlabs ControlNet.
                if controlnet_blocks_repeat:
                    hidden_states = (
                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                    )
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if self.sp_world_size > 1:
            output = self.all_gather(output, dim=1)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
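# End-to-end sketch (illustrative, not upstream code; assumes the attention
# helper from .attention_utils can run on the current device). A deliberately
# tiny config that mirrors the structure of the defaults; note sum(axes_dims_rope)
# must equal attention_head_dim.
#
#   model = FluxTransformer2DModel(
#       patch_size=1, in_channels=4, num_layers=1, num_single_layers=1,
#       attention_head_dim=8, num_attention_heads=2, joint_attention_dim=16,
#       pooled_projection_dim=8, axes_dims_rope=(2, 2, 4),
#   )
#   out = model(
#       hidden_states=torch.randn(1, 16, 4),
#       encoder_hidden_states=torch.randn(1, 6, 16),
#       pooled_projections=torch.randn(1, 8),
#       timestep=torch.tensor([1.0]),
#       img_ids=torch.zeros(16, 3),
#       txt_ids=torch.zeros(6, 3),
#   ).sample                              # [1, 16, 4]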
videox_fun/models/qwenimage_transformer2d.py
ADDED
@@ -0,0 +1,893 @@
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_qwenimage.py
# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import glob
import json
import math
import os
import types
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import (AttentionProcessor,
                                                  CogVideoXAttnProcessor2_0,
                                                  FusedCogVideoXAttnProcessor2_0)
from diffusers.models.embeddings import (CogVideoXPatchEmbed,
                                         TimestepEmbedding, Timesteps,
                                         get_3d_sincos_pos_embed)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import (AdaLayerNorm,
                                            AdaLayerNormContinuous,
                                            CogVideoXLayerNormZero, RMSNorm)
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
                             scale_lora_layers, unscale_lora_layers)
from diffusers.utils.torch_utils import maybe_allow_in_graph

from ..dist import (QwenImageMultiGPUsAttnProcessor2_0,
                    get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group)
from .attention_utils import attention

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: create sinusoidal timestep
    embeddings.

    Args:
        timesteps (torch.Tensor):
            A 1-D tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            The dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions.
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent).to(timesteps.dtype)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


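# Worked example (added for illustration): with embedding_dim=8 the first four
# channels are sines and the last four are cosines of the same arguments, since
# flip_sin_to_cos defaults to False.
#
#   emb = get_timestep_embedding(torch.tensor([0.0, 500.0]), embedding_dim=8)
#   emb.shape    # torch.Size([2, 8])
#   emb[0]       # sin(0) == 0 in the first half, cos(0) == 1 in the second

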
def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to an input tensor using the given frequency tensor. This function applies rotary
    embeddings to the given query or key tensor 'x' using the provided frequency tensor 'freqs_cis'. The input tensor
    is reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensor contains the rotary embeddings and is returned as a real tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, of shape [B, S, H, D].
        freqs_cis (`Tuple[torch.Tensor]`):
            Precomputed frequency tensor for complex exponentials, ([S, D], [S, D],).

    Returns:
        torch.Tensor: The query or key tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(1)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


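# Shape sketch (illustrative): the complex path used by the Qwen processor below
# (use_real=False) expects x as [B, S, H, D] and freqs_cis as a complex [S, D//2]
# tensor, e.g. built with torch.polar.
#
#   x = torch.randn(1, 6, 2, 8)                                # [B, S, H, D]
#   freqs = torch.polar(torch.ones(6, 4), torch.randn(6, 4))   # complex [S, D//2]
#   out = apply_rotary_emb_qwen(x, freqs, use_real=False)      # same shape/dtype as x

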
class QwenTimestepProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, hidden_states):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)

        conditioning = timesteps_emb

        return conditioning


class QwenEmbedRope(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        self.pos_freqs = torch.cat(
            [
                self.rope_params(pos_index, self.axes_dim[0], self.theta),
                self.rope_params(pos_index, self.axes_dim[1], self.theta),
                self.rope_params(pos_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.neg_freqs = torch.cat(
            [
                self.rope_params(neg_index, self.axes_dim[0], self.theta),
                self.rope_params(neg_index, self.axes_dim[1], self.theta),
                self.rope_params(neg_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.rope_cache = {}

        # Do NOT use register_buffer here: it would cause the complex numbers to lose their imaginary part.
        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """
        Args:
            index: 1-D tensor of token position indices, e.g. [0, 1, 2, 3].
        """
        assert dim % 2 == 0
        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def forward(self, video_fhw, txt_seq_lens, device):
        """
        Args:
            video_fhw: a [frame, height, width] triple (or a list of such triples) describing the latent video shape.
            txt_seq_lens: a list of per-sample text sequence lengths, of length [batch_size].
        """
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        if isinstance(video_fhw, list):
            video_fhw = video_fhw[0]
        if not isinstance(video_fhw, list):
            video_fhw = [video_fhw]

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            rope_key = f"{idx}_{height}_{width}"

            if not torch.compiler.is_compiling():
                if rope_key not in self.rope_cache:
                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
                video_freq = self.rope_cache[rope_key]
            else:
                video_freq = self._compute_video_freqs(frame, height, width, idx)
            video_freq = video_freq.to(device)
            vid_freqs.append(video_freq)

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_len = max(txt_seq_lens)
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs

    @functools.lru_cache(maxsize=None)
    def _compute_video_freqs(self, frame, height, width, idx=0):
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
        if self.scale_rope:
            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
        else:
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
        return freqs.clone().contiguous()


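# Usage sketch (illustrative, not upstream code): frequencies for a single
# 1x32x32 latent image and a 77-token prompt. Each axes_dim entry contributes
# dim // 2 complex channels, so (16 + 56 + 56) // 2 = 64 channels here.
#
#   rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
#   vid_freqs, txt_freqs = rope([(1, 32, 32)], txt_seq_lens=[77], device="cpu")
#   vid_freqs.shape, txt_freqs.shape   # ([1024, 64], [77, 64]), complex-valued

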
class QwenDoubleStreamAttnProcessor2_0:
    """
    Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
    implements joint attention computation where text and image streams are processed together.
    """

    _attention_backend = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,  # Image stream
        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

        seq_txt = encoder_hidden_states.shape[1]

        # Compute QKV for image stream (sample projections)
        img_query = attn.to_q(hidden_states)
        img_key = attn.to_k(hidden_states)
        img_value = attn.to_v(hidden_states)

        # Compute QKV for text stream (context projections)
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

        # Reshape for multi-head attention
        img_query = img_query.unflatten(-1, (attn.heads, -1))
        img_key = img_key.unflatten(-1, (attn.heads, -1))
        img_value = img_value.unflatten(-1, (attn.heads, -1))

        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply RoPE
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        # Concatenate for joint attention
        # Order: [text, image]
        joint_query = torch.cat([txt_query, img_query], dim=1)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        joint_hidden_states = attention(
            joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, causal=False
        )

        # Reshape back
        joint_hidden_states = joint_hidden_states.flatten(2, 3)
        joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

        # Split attention outputs back
        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part

        # Apply output projections
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # dropout

        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output


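# Note (added for clarity): because the joint sequence is ordered [text, image],
# the first seq_txt positions of the attention output map back to the text
# stream and the remainder to the image stream. This is also why QwenEmbedRope
# returns separate text and image frequency tables instead of one fused table.

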
@maybe_allow_in_graph
class QwenImageTransformerBlock(nn.Module):
    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()

        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        # Image processing modules
        self.img_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
        )
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,  # Enable cross attention for joint computation
            added_kv_proj_dim=dim,  # Enable added KV projections for text stream
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=QwenDoubleStreamAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=eps,
        )
        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # Text processing modules
        self.txt_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
        )
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        # Text doesn't need separate attention - it's handled by the joint computation in self.attn
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def _modulate(self, x, mod_params):
        """Apply modulation to the input tensor."""
        shift, scale, gate = mod_params.chunk(3, dim=-1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_mask: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Get modulation parameters for both streams
        img_mod_params = self.img_mod(temb)  # [B, 6*dim]
        txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]

        # Split modulation parameters for norm1 and norm2
        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]

        # Process image stream - norm1 + modulation
        img_normed = self.img_norm1(hidden_states)
        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)

        # Process text stream - norm1 + modulation
        txt_normed = self.txt_norm1(encoder_hidden_states)
        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)

        # Use QwenDoubleStreamAttnProcessor2_0 for joint attention computation.
        # This directly implements the DoubleStreamLayerMegatron logic:
        # 1. Computes QKV for both streams
        # 2. Applies QK normalization and RoPE
        # 3. Concatenates and runs joint attention
        # 4. Splits results back to separate streams
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=img_modulated,  # Image stream (will be processed as "sample")
            encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        # The processor returns (img_output, txt_output) when encoder_hidden_states is provided
        img_attn_output, txt_attn_output = attn_output

        # Apply attention gates and add residual (like in Megatron)
        hidden_states = hidden_states + img_gate1 * img_attn_output
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

        # Process image stream - norm2 + MLP
        img_normed2 = self.img_norm2(hidden_states)
        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
        img_mlp_output = self.img_mlp(img_modulated2)
        hidden_states = hidden_states + img_gate2 * img_mlp_output

        # Process text stream - norm2 + MLP
        txt_normed2 = self.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
        txt_mlp_output = self.txt_mlp(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

        # Clip to prevent overflow for fp16
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states


| 484 |
+
class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    The Transformer model introduced in Qwen.

    Args:
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `60`):
            The number of layers of dual-stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `3584`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for the guidance-distilled variant of the model.
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    # _supports_gradient_checkpointing = True
    # _no_split_modules = ["QwenImageTransformerBlock"]
    # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    # _repeated_blocks = ["QwenImageTransformerBlock"]
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 64,
        out_channels: Optional[int] = 16,
        num_layers: int = 60,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 3584,
        guidance_embeds: bool = False,  # TODO: this should probably be removed
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)

        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

        self.img_in = nn.Linear(in_channels, self.inner_dim)
        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

        self.sp_world_size = 1
        self.sp_world_rank = 0

    def _set_gradient_checkpointing(self, *args, **kwargs):
        if "value" in kwargs:
            self.gradient_checkpointing = kwargs["value"]
        elif "enable" in kwargs:
            self.gradient_checkpointing = kwargs["enable"]
        else:
            raise ValueError("Expected a `value` or `enable` keyword argument when setting gradient checkpointing.")

    def enable_multi_gpus_inference(self):
        self.sp_world_size = get_sequence_parallel_world_size()
        self.sp_world_rank = get_sequence_parallel_rank()
        self.all_gather = get_sp_group().all_gather
        self.set_attn_processor(QwenImageMultiGPUsAttnProcessor2_0())

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
        txt_seq_lens: Optional[List[int]] = None,
        guidance: torch.Tensor = None,  # TODO: this should probably be removed
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`QwenTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
                Mask of the input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate the denoising step.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.img_in(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

        # Context Parallel: each rank keeps only its slice of the image sequence
        if self.sp_world_size > 1:
            hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
            if image_rotary_emb is not None:
                image_rotary_emb = (
                    torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
                    image_rotary_emb[1],
                )

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=attention_kwargs,
                )

        # Use only the image part (hidden_states) from the dual-stream blocks
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if self.sp_world_size > 1:
            output = self.all_gather(output, dim=1)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    @classmethod
    def from_pretrained(
        cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
        low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
    ):
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"Loading the transformer's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)

        from diffusers.utils import WEIGHTS_NAME
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        model_file_safetensors = model_file.replace(".bin", ".safetensors")

        if "dict_mapping" in transformer_additional_kwargs.keys():
            for key in transformer_additional_kwargs["dict_mapping"]:
                transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]

        if low_cpu_mem_usage:
            try:
                import re

                from diffusers import __version__ as diffusers_version
                if diffusers_version >= "0.33.0":
                    from diffusers.models.model_loading_utils import \
                        load_model_dict_into_meta
                else:
                    from diffusers.models.modeling_utils import \
                        load_model_dict_into_meta
                from diffusers.utils import is_accelerate_available
                if is_accelerate_available():
                    import accelerate

                # Instantiate the model with empty weights
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **transformer_additional_kwargs)

                param_device = "cpu"
                if os.path.exists(model_file):
                    state_dict = torch.load(model_file, map_location="cpu")
                elif os.path.exists(model_file_safetensors):
                    from safetensors.torch import load_file, safe_open
                    state_dict = load_file(model_file_safetensors)
                else:
                    from safetensors.torch import load_file, safe_open
                    model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
                    state_dict = {}
                    print(model_files_safetensors)
                    for _model_file_safetensors in model_files_safetensors:
                        _state_dict = load_file(_model_file_safetensors)
                        for key in _state_dict:
                            state_dict[key] = _state_dict[key]

                if diffusers_version >= "0.33.0":
                    # Diffusers refactored `load_model_dict_into_meta` in version 0.33.0 in this commit:
                    # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
                    load_model_dict_into_meta(
                        model,
                        state_dict,
                        dtype=torch_dtype,
                        model_name_or_path=pretrained_model_path,
                    )
                else:
                    model._convert_deprecated_attention_blocks(state_dict)
                    # move the params from the meta device to cpu
                    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                    if len(missing_keys) > 0:
                        raise ValueError(
                            f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
                            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                            " those weights or else make sure your checkpoint file is correct."
                        )

                    unexpected_keys = load_model_dict_into_meta(
                        model,
                        state_dict,
                        device=param_device,
                        dtype=torch_dtype,
                        model_name_or_path=pretrained_model_path,
                    )

                    if cls._keys_to_ignore_on_load_unexpected is not None:
                        for pat in cls._keys_to_ignore_on_load_unexpected:
                            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                    if len(unexpected_keys) > 0:
                        print(
                            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                        )

                return model
            except Exception as e:
                print(
                    f"The low_cpu_mem_usage mode did not work because of {e}. Falling back to low_cpu_mem_usage=False."
                )

        model = cls.from_config(config, **transformer_additional_kwargs)
        if os.path.exists(model_file):
            state_dict = torch.load(model_file, map_location="cpu")
        elif os.path.exists(model_file_safetensors):
            from safetensors.torch import load_file, safe_open
            state_dict = load_file(model_file_safetensors)
        else:
            from safetensors.torch import load_file, safe_open
            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
            state_dict = {}
            for _model_file_safetensors in model_files_safetensors:
                _state_dict = load_file(_model_file_safetensors)
                for key in _state_dict:
                    state_dict[key] = _state_dict[key]

        tmp_state_dict = {}
        for key in state_dict:
            if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
                tmp_state_dict[key] = state_dict[key]
            else:
                print(key, "size mismatch, skipping")

        state_dict = tmp_state_dict

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m)

        params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
        print(f"### All Parameters: {sum(params) / 1e6} M")

        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
        print(f"### attn1 Parameters: {sum(params) / 1e6} M")

        model = model.to(torch_dtype)
        return model
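
if __name__ == "__main__":
    # Minimal smoke-test sketch (added for illustration; not part of the original
    # commit). It builds a deliberately tiny configuration and runs one forward
    # pass on random tensors; the sizes below are illustrative assumptions, not
    # the released Qwen-Image configuration (60 layers, head dim 128).
    model = QwenImageTransformer2DModel(
        patch_size=2,
        in_channels=16,
        out_channels=16,
        num_layers=2,
        attention_head_dim=32,
        num_attention_heads=4,
        joint_attention_dim=64,
        axes_dims_rope=(8, 12, 12),  # must sum to attention_head_dim
    )
    b, h, w, txt = 1, 8, 8, 5  # h*w packed image tokens, txt text tokens
    hidden_states = torch.randn(b, h * w, 16)
    encoder_hidden_states = torch.randn(b, txt, 64)
    encoder_hidden_states_mask = torch.ones(b, txt, dtype=torch.long)
    out = model(
        hidden_states,
        encoder_hidden_states,
        encoder_hidden_states_mask,
        timestep=torch.tensor([500]),
        img_shapes=[(1, h, w)],
        txt_seq_lens=[txt],
        return_dict=False,
    )[0]
    print(out.shape)  # expected: [b, h*w, patch_size**2 * out_channels] == [1, 64, 64]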
videox_fun/models/qwenimage_vae.py
ADDED
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We gratefully acknowledge the Wan Team for their outstanding contributions.
# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
# For more information about the Wan VAE, please refer to:
# - GitHub: https://github.com/Wan-Video/Wan2.1
# - arXiv: https://arxiv.org/abs/2503.20314

import functools
import glob
import json
import math
import os
import types
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.activations import get_activation
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.autoencoders.vae import (DecoderOutput,
                                               DiagonalGaussianDistribution)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import (AutoencoderKLOutput,
                                               Transformer2DModelOutput)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
                             scale_lora_layers, unscale_lora_layers)
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.utils.torch_utils import maybe_allow_in_graph

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

CACHE_T = 2


class QwenImageCausalConv3d(nn.Conv3d):
    r"""
    A custom 3D causal convolution layer with feature caching support.

    This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
    caching for efficient inference.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        # Set up causal padding as (W_left, W_right, H_top, H_bottom, T_front, T_back):
        # all temporal padding goes in front of the sequence, so no output frame
        # ever depends on a future input frame.
        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            # Cached frames from the previous chunk replace part of the temporal padding
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)
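# Worked example (illustrative note, not in the original commit): with
# kernel_size=3 and padding=1, `_padding` becomes (1, 1, 1, 1, 2, 0), i.e. one
# pixel of spatial padding on each side but two frames of padding only at the
# front of the time axis. A first chunk of T frames therefore yields T output
# frames, and for later chunks the CACHE_T cached frames stand in for the front
# padding, so chunked and full-clip inference produce the same result.
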
class QwenImageRMS_norm(nn.Module):
    r"""
    A custom RMS normalization layer.

    Args:
        dim (int): The number of dimensions to normalize over.
        channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
            Default is True.
        images (bool, optional): Whether the input represents image data. Default is True.
        bias (bool, optional): Whether to include a learnable bias term. Default is False.
    """

    def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias


class QwenImageUpsample(nn.Upsample):
    r"""
    Perform upsampling while ensuring the output tensor has the same data type as the input.

    Args:
        x (torch.Tensor): Input tensor to be upsampled.

    Returns:
        torch.Tensor: Upsampled tensor with the same data type as the input.
    """

    def forward(self, x):
        # Upsample in float32 for numerical stability, then cast back
        return super().forward(x.float()).type_as(x)

class QwenImageResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. Must be one of:
            - 'none': No resampling (identity operation).
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
            - 'downsample2d': 2D downsampling with zero-padding and convolution.
            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
    """

    def __init__(self, dim: int, mode: str) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim // 2, 3, padding=1),
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim // 2, 3, padding=1),
            )
            self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == "downsample2d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # cache the last frame of the previous chunk
                        cache_x = torch.cat(
                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
                        )
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.resample(x)
        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)

        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

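# Illustrative note (not in the original commit): in `upsample3d` mode the
# causal time_conv doubles the channel count (dim -> 2*dim), and the
# reshape/stack above interleaves those two halves along the time axis, turning
# t input frames into 2*t output frames; the 2D `resample` branch then doubles
# height and width frame by frame, so one `upsample3d` block upsamples 2x in
# time and 2x in space.
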
class QwenImageResidualBlock(nn.Module):
    r"""
    A custom residual block module.

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
        non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = get_activation(non_linearity)

        # layers
        self.norm1 = QwenImageRMS_norm(in_dim, images=False)
        self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = QwenImageRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
        self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # Apply shortcut connection
        h = self.conv_shortcut(x)

        # First normalization and activation
        x = self.norm1(x)
        x = self.nonlinearity(x)

        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        # Second normalization and activation
        x = self.norm2(x)
        x = self.nonlinearity(x)

        # Dropout
        x = self.dropout(x)

        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

            x = self.conv2(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv2(x)

        # Add residual connection
        return x + h

class QwenImageAttentionBlock(nn.Module):
    r"""
    Causal self-attention with a single head, applied within each frame.

    Args:
        dim (int): The number of channels in the input tensor.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = QwenImageRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        identity = x
        batch_size, channels, time, height, width = x.size()

        # Fold time into the batch so attention runs per frame: [b, c, t, h, w] -> [(b*t), c, h, w]
        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
        x = self.norm(x)

        # compute query, key, value
        qkv = self.to_qkv(x)
        qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
        qkv = qkv.permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(q, k, v)

        x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)

        # output projection
        x = self.proj(x)

        # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
        x = x.view(batch_size, time, channels, height, width)
        x = x.permute(0, 2, 1, 3, 4)

        return x + identity

class QwenImageMidBlock(nn.Module):
    """
    Middle block for the QwenImageVAE encoder and decoder.

    Args:
        dim (int): Number of input/output channels.
        dropout (float): Dropout rate.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
        super().__init__()
        self.dim = dim

        # Create the components: resnet -> (attention -> resnet) * num_layers
        resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
        attentions = []
        for _ in range(num_layers):
            attentions.append(QwenImageAttentionBlock(dim))
            resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # First residual block
        x = self.resnets[0](x, feat_cache, feat_idx)

        # Process through attention and residual blocks
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                x = attn(x)

            x = resnet(x, feat_cache, feat_idx)

        return x

class QwenImageEncoder3d(nn.Module):
    r"""
    A 3D encoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.nonlinearity = get_activation(non_linearity)

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        self.down_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    self.down_blocks.append(QwenImageAttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
                scale /= 2.0

        # middle blocks
        self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)

        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the previous chunk
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)

        ## downsamples
        for layer in self.down_blocks:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the previous chunk
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x

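# Illustrative note (not in the original commit): with dim_mult=[1, 2, 4, 4]
# the encoder applies three resample stages, so spatial resolution drops by
# 2**3 = 8x, while temporal resolution drops by 2x for each True entry in
# temperal_downsample (4x in total for the defaults, which contain two Trues).
# For a single image the time axis has length 1, so only the spatial factors
# matter.
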
class QwenImageUpBlock(nn.Module):
    """
    A block that handles upsampling for the QwenImageVAE decoder.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: Optional[str] = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Create layers list
        resnets = []
        # Add residual blocks and attention if needed
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed
        self.upsamplers = None
        if upsample_mode is not None:
            self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management

        Returns:
            torch.Tensor: Output tensor
        """
        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsamplers is not None:
            if feat_cache is not None:
                x = self.upsamplers[0](x, feat_cache, feat_idx)
            else:
                x = self.upsamplers[0](x)
        return x

class QwenImageDecoder3d(nn.Module):
|
| 568 |
+
r"""
|
| 569 |
+
A 3D decoder module.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
dim (int): The base number of channels in the first layer.
|
| 573 |
+
z_dim (int): The dimensionality of the latent space.
|
| 574 |
+
dim_mult (list of int): Multipliers for the number of channels in each block.
|
| 575 |
+
num_res_blocks (int): Number of residual blocks in each block.
|
| 576 |
+
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
| 577 |
+
temperal_upsample (list of bool): Whether to upsample temporally in each block.
|
| 578 |
+
dropout (float): Dropout rate for the dropout layers.
|
| 579 |
+
non_linearity (str): Type of non-linearity to use.
|
| 580 |
+
"""
|
| 581 |
+
|
| 582 |
+
def __init__(
|
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        self.nonlinearity = get_activation(non_linearity)

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2 ** (len(dim_mult) - 2)

        # init block
        self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)

        # upsample blocks
        self.up_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i > 0:
                in_dim = in_dim // 2

            # Determine if we need upsampling
            upsample_mode = None
            if i != len(dim_mult) - 1:
                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"

            # Create and add the upsampling block
            up_block = QwenImageUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                upsample_mode=upsample_mode,
                non_linearity=non_linearity,
            )
            self.up_blocks.append(up_block)

            # Update scale for next iteration
            if upsample_mode is not None:
                scale *= 2.0

        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv_in
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # short chunk: also keep the last frame of the previous chunk so the cache spans two frames
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # short chunk: also keep the last frame of the previous chunk so the cache spans two frames
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x

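# Sketch (illustrative, not part of the original file; kept commented out so the
# module is unchanged): the feat_cache/feat_idx plumbing above lets causal
# temporal convs run chunk by chunk yet match a full-clip pass, because each
# call is re-fed the last CACHE_T (assumed 2, matching k=3 causal convs) frames
# of the padded stream. Uncommented, this would run with the module's imports:
#
# def causal_conv_time(x, weight, cache=None):           # x: [B, C, T], k = 3
#     pad = torch.zeros(x.size(0), x.size(1), 2) if cache is None else cache
#     y = F.conv1d(torch.cat([pad, x], dim=2), weight)   # left-pad only
#     return y, torch.cat([pad, x], dim=2)[:, :, -2:].clone()
#
# w = torch.randn(4, 4, 3); x = torch.randn(1, 4, 9)
# full, _ = causal_conv_time(x, w)
# outs, cache = [], None
# for chunk in (x[:, :, :1], x[:, :, 1:5], x[:, :, 5:9]):
#     y, cache = causal_conv_time(chunk, w, cache)
#     outs.append(y)
# assert torch.allclose(torch.cat(outs, dim=2), full, atol=1e-5)  # chunked == full
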
class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
    """

    _supports_gradient_checkpointing = False

    # fmt: off
    @register_to_config
    def __init__(
        self,
        base_dim: int = 96,
        z_dim: int = 16,
        dim_mult: Tuple[int] = [1, 2, 4, 4],
        num_res_blocks: int = 2,
        attn_scales: List[float] = [],
        temperal_downsample: List[bool] = [False, True, True],
        dropout: float = 0.0,
        latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
        latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
    ) -> None:
        # fmt: on
        super().__init__()

        self.z_dim = z_dim
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        self.encoder = QwenImageEncoder3d(
            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
        )
        self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)

        self.decoder = QwenImageDecoder3d(
            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
        )

        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
        self.use_slicing = False

        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
        # intermediate tiles together, the memory requirement can be lowered.
        self.use_tiling = False

        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 256
        self.tile_sample_min_width = 256

        # The minimal distance between two spatial tiles
        self.tile_sample_stride_height = 192
        self.tile_sample_stride_width = 192

        # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
        self._cached_conv_counts = {
            "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())
            if self.decoder is not None
            else 0,
            "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())
            if self.encoder is not None
            else 0,
        }

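    # Example (illustrative sketch, commented out so the module is unchanged;
    # shapes follow from the defaults above): a [B, 3, T, H, W] video maps to a
    # [B, 16, 1 + (T - 1) // 4, H // 8, W // 8] latent, which downstream
    # pipelines typically normalize with the registered config statistics:
    #
    # vae = AutoencoderKLQwenImage()
    # mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1)
    # std = torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1)
    # z = vae.encode(torch.randn(1, 3, 1, 256, 256)).latent_dist.sample()
    # z_normed = (z - mean) / std                       # z: [1, 16, 1, 32, 32]
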
    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_sample_stride_height: Optional[float] = None,
        tile_sample_stride_width: Optional[float] = None,
    ) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The stride between two consecutive vertical tiles. Keeping it smaller than the tile height ensures
                an overlap, so that there are no tiling artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. Keeping it smaller than the tile width ensures
                an overlap, so that there are no tiling artifacts produced across the width dimension.
        """
        self.use_tiling = True
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.use_slicing = False

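    # Example (illustrative, commented out): both memory savers can be toggled
    # independently before calling `decode`; slicing is exact, while tiling may
    # differ slightly at tile seams where blending kicks in.
    #
    # vae.enable_slicing()      # decode batch items one at a time
    # vae.enable_tiling(tile_sample_min_height=256, tile_sample_stride_height=192)
    # frames = vae.decode(latents).sample
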
    def clear_cache(self):
        def _count_conv3d(model):
            count = 0
            for m in model.modules():
                if isinstance(m, QwenImageCausalConv3d):
                    count += 1
            return count

        self._conv_num = _count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # caches for encoding
        self._enc_conv_num = _count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num

    def _encode(self, x: torch.Tensor):
        _, _, num_frame, height, width = x.shape

        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)

        self.clear_cache()
        iter_ = 1 + (num_frame - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)

        enc = self.quant_conv(out)
        self.clear_cache()
        return enc

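    # Worked example: for num_frame = 17, iter_ = 1 + (17 - 1) // 4 = 5, so the
    # encoder sees frame 0 alone and then four chunks of four frames each
    # (1-4, 5-8, 9-12, 13-16); the causal feat_cache carries temporal context
    # across chunk boundaries, so the result matches a single full-clip pass.
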
    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        r"""
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded videos. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)
        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

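    # Example (illustrative, commented out): `encode` returns a distribution,
    # not a tensor; sample stochastically for training, or take the mode for
    # deterministic inference.
    #
    # posterior = vae.encode(video).latent_dist
    # z = posterior.sample(generator=torch.Generator().manual_seed(0))
    # z_det = posterior.mode()
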
    def _decode(self, z: torch.Tensor, return_dict: bool = True):
        _, _, num_frame, height, width = z.shape
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
            return self.tiled_decode(z, return_dict=return_dict)

        self.clear_cache()
        x = self.post_quant_conv(z)
        for i in range(num_frame):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)

        out = torch.clamp(out, min=-1.0, max=1.0)
        self.clear_cache()
        if not return_dict:
            return (out,)

        return DecoderOutput(sample=out)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

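    # Worked example: with blend_extent = 4, row y of the overlap mixes the
    # bottom of tile `a` and the top of tile `b` with weights
    # (1 - y/4, y/4) = (1.00, 0.00), (0.75, 0.25), (0.50, 0.50), (0.25, 0.75)
    # for y = 0..3, i.e. a linear cross-fade from `a` into `b` across the seam.
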
    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        _, _, num_frames, height, width = x.shape
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                self.clear_cache()
                time = []
                frame_range = 1 + (num_frames - 1) // 4
                for k in range(frame_range):
                    self._enc_conv_idx = [0]
                    if k == 0:
                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                    else:
                        tile = x[
                            :,
                            :,
                            1 + 4 * (k - 1) : 1 + 4 * k,
                            i : i + self.tile_sample_min_height,
                            j : j + self.tile_sample_min_width,
                        ]
                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
                    tile = self.quant_conv(tile)
                    time.append(tile)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the tile above and the tile to the left into the
                # current tile, then add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return enc

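    # Worked example: with the 256/192 defaults and 8x spatial compression, a
    # 512x512 input yields a 3x3 tile grid (one tile per 192-pixel stride);
    # each latent tile is 32x32, consecutive tiles overlap by
    # (256 - 192) / 8 = 8 latent rows/columns, and that 8-wide band is exactly
    # what blend_v/blend_h cross-fade before the tiles are stitched back.
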
    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        _, _, num_frames, height, width = z.shape
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                self.clear_cache()
                time = []
                for k in range(num_frames):
                    self._conv_idx = [0]
                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                    tile = self.post_quant_conv(tile)
                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                    time.append(decoded)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the tile above and the tile to the left into the
                # current tile, then add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """
        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior distribution or take its mode.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, return_dict=return_dict)
        return dec

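# Example usage (illustrative reconstruction round trip; commented out so the
# module has no import-time side effects, and the printed shape is the
# expected value for the default config):
# vae = AutoencoderKLQwenImage()
# video = torch.randn(1, 3, 5, 256, 256)            # [B, C, T, H, W] in [-1, 1]
# recon = vae(video, sample_posterior=True).sample
# print(recon.shape)                                # torch.Size([1, 3, 5, 256, 256])
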
videox_fun/models/wan_camera_adapter.py
ADDED
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn


class SimpleAdapter(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1):
        super(SimpleAdapter, self).__init__()

        # Pixel Unshuffle: reduce spatial dimensions by a factor of `downscale_factor` (8 by default)
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)

        # Convolution: reduce spatial dimensions further, e.g. by a factor
        # of 2 without overlap when kernel_size == stride == 2
        self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim, kernel_size=kernel_size, stride=stride, padding=0)

        # Residual blocks for feature extraction
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        # Reshape to merge the frame dimension into the batch dimension
        bs, c, f, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)

        # Pixel Unshuffle operation
        x_unshuffled = self.pixel_unshuffle(x)

        # Convolution operation
        x_conv = self.conv(x_unshuffled)

        # Feature extraction with residual blocks
        out = self.residual_blocks(x_conv)

        # Reshape to restore the original batch/frame split
        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))

        # Permute back to channels-first video layout [B, C, F, H, W]
        out = out.permute(0, 2, 1, 3, 4)

        return out


class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        out += residual
        return out

# Example usage
# in_dim = 3
# out_dim = 64
# adapter = SimpleAdapter(in_dim, out_dim, kernel_size=2, stride=2)
# x = torch.randn(1, in_dim, 4, 64, 64)  # e.g., batch size = 1, channels = 3, frames = 4
# output = adapter(x)
# print(output.shape)  # Should reflect the transformed dimensions
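
# Shape walk-through for the example above (the kernel_size=2/stride=2 values
# are assumptions matching the "factor of 2, without overlap" comment): a
# [1, 3, 4, 64, 64] input is flattened to [4, 3, 64, 64]; PixelUnshuffle(8)
# gives [4, 3 * 64, 8, 8]; the 2x2/stride-2 conv gives [4, 64, 4, 4]; after the
# residual blocks and reshapes the output is [1, 64, 4, 4, 4], i.e. a 16x
# overall spatial reduction (8x from the unshuffle, 2x from the conv).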
videox_fun/models/wan_image_encoder.py
ADDED
@@ -0,0 +1,553 @@
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from .attention_utils import attention, flash_attention
from .wan_xlm_roberta import XLMRoberta
from diffusers.configuration_utils import ConfigMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.modeling_utils import ModelMixin


__all__ = [
    'XLMRobertaCLIP',
    'clip_xlm_roberta_vit_h_14',
    'CLIPModel',
]


def pos_interpolate(pos, seq_len):
    if pos.size(1) == seq_len:
        return pos
    else:
        src_grid = int(math.sqrt(pos.size(1)))
        tar_grid = int(math.sqrt(seq_len))
        n = pos.size(1) - src_grid * src_grid
        return torch.cat([
            pos[:, :n],
            F.interpolate(
                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
                size=(tar_grid, tar_grid),
                mode='bicubic',
                align_corners=False).flatten(2).transpose(1, 2)
        ], dim=1)

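# Example (illustrative, commented out): for a ViT with 224/14 = 16x16 patches
# plus one class token, pos.size(1) == 257; feeding a 448x448 image yields
# 32x32 = 1024 patch tokens, and pos_interpolate keeps the class-token slot
# (n == 1) while bicubically resizing the 16x16 positional grid to 32x32:
# pos = torch.randn(1, 257, 1280)
# pos_448 = pos_interpolate(pos, seq_len=1025)   # -> torch.Size([1, 1025, 1280])
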
class QuickGELU(nn.Module):

    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 causal=False,
                 attn_dropout=0.0,
                 proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # layers
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)

        # compute attention
        p = self.attn_dropout if self.training else 0.0
        x = attention(q, k, v, dropout_p=p, causal=self.causal, attention_type="none")
        x = x.reshape(b, s, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)
        return x


class SwiGLU(nn.Module):

    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # layers
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        x = F.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        return x


class AttentionBlock(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # layers
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
        else:
            self.mlp = nn.Sequential(
                nn.Linear(dim, int(dim * mlp_ratio)),
                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
                nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            x = x + self.norm1(self.attn(x))
            x = x + self.norm2(self.mlp(x))
        else:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        return x


class AttentionPool(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
        x = flash_attention(q, k, v, version=2)
        x = x.reshape(b, 1, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # mlp
        x = x + self.mlp(self.norm(x))
        return x[:, 0]

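# Sketch (illustrative, commented out): `AttentionPool` reduces [B, L, C] to
# [B, C] by attending with a single learned query token; per head it is
# roughly equivalent to one weighted average over all tokens:
# w = torch.softmax(q_cls @ tokens.transpose(-2, -1) / math.sqrt(d), dim=-1)
# pooled = w @ tokens
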
class VisionTransformer(nn.Module):

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        if image_size % patch_size != 0:
            print(
                '[WARNING] image_size is not divisible by patch_size',
                flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size)**2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm)
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches +
            (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
                                      proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        b = x.size(0)

        # embeddings
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x

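# Note (illustrative, commented out): with `use_31_block=True` the forward
# stops one block early, so for the 32-layer ViT-H configured below it returns
# the penultimate-block token sequence (expected [B, 257, 1280] at 224x224)
# instead of the pooled CLIP embedding; `CLIPModel.forward` further down uses
# exactly this feature map as image conditioning:
# feats = visual(torch.randn(1, 3, 224, 224), use_31_block=True)
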
class XLMRobertaWithHead(XLMRoberta):

    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
        x = super().forward(ids)

        # average pooling over non-padding tokens
        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)

        # head
        x = self.head(x)
        return x


class XLMRobertaCLIP(nn.Module):

    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
                 vision_layers=32,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 activation='gelu',
                 vocab_size=250002,
                 max_text_len=514,
                 type_size=1,
                 pad_id=1,
                 text_dim=1024,
                 text_heads=16,
                 text_layers=24,
                 text_post_norm=True,
                 text_dropout=0.1,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_post_norm = text_post_norm
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = XLMRobertaWithHead(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
            type_size=type_size,
            pad_id=pad_id,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
            dropout=text_dropout)
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
        """
        imgs: [B, 3, H, W] of torch.float32, normalized with
        - mean: [0.48145466, 0.4578275, 0.40821073]
        - std: [0.26862954, 0.26130258, 0.27577711]
        txt_ids: [B, L] of torch.long, encoded by data.CLIPTokenizer.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def param_groups(self):
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups


def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=XLMRobertaCLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std
        if 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([
            T.Resize((model.image_size, model.image_size),
                     interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
        output += (transforms,)
    return output[0] if len(output) == 1 else output


def clip_xlm_roberta_vit_h_14(
        pretrained=False,
        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
        **kwargs):
    cfg = dict(
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool='token',
        activation='gelu',
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        text_dim=1024,
        text_heads=16,
        text_layers=24,
        text_post_norm=True,
        text_dropout=0.1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0)
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)


class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):

    def __init__(self):
        super(CLIPModel, self).__init__()
        # init model
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False)

    def forward(self, videos):
        # preprocess
        size = (self.model.image_size,) * 2
        videos = torch.cat([
            F.interpolate(
                u.transpose(0, 1),
                size=size,
                mode='bicubic',
                align_corners=False) for u in videos
        ])
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward
        with torch.cuda.amp.autocast(dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)
        return out

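    # Example (illustrative, commented out): `videos` is a list of per-clip
    # tensors shaped [C, F, H, W] in [-1, 1]; each clip is resized to 224x224,
    # re-normalized to CLIP statistics, and encoded to penultimate-block ViT
    # features (expected [F, 257, 1280] per clip):
    # clip = CLIPModel()
    # feats = clip([torch.randn(3, 2, 480, 832)])
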
    @classmethod
    def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}):
        def filter_kwargs(cls, kwargs):
            import inspect
            sig = inspect.signature(cls.__init__)
            valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
            return filtered_kwargs

        model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
        if pretrained_model_path.endswith(".safetensors"):
            from safetensors.torch import load_file
            state_dict = load_file(pretrained_model_path)
        else:
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
        tmp_state_dict = {}
        for key in state_dict:
            tmp_state_dict["model." + key] = state_dict[key]
        state_dict = tmp_state_dict
        m, u = model.load_state_dict(state_dict)

        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m, u)
        return model
videox_fun/models/wan_text_encoder.py
ADDED
@@ -0,0 +1,395 @@
| 1 |
+
# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from diffusers.configuration_utils import ConfigMixin
|
| 10 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def fp16_clamp(x):
|
| 15 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
| 16 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
| 17 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
| 18 |
+
return x
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def init_weights(m):
|
| 22 |
+
if isinstance(m, T5LayerNorm):
|
| 23 |
+
nn.init.ones_(m.weight)
|
| 24 |
+
elif isinstance(m, T5FeedForward):
|
| 25 |
+
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
| 26 |
+
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
| 27 |
+
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
| 28 |
+
elif isinstance(m, T5Attention):
|
| 29 |
+
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
|
| 30 |
+
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
| 31 |
+
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
| 32 |
+
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
|
| 33 |
+
elif isinstance(m, T5RelativeEmbedding):
|
| 34 |
+
nn.init.normal_(
|
| 35 |
+
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class GELU(nn.Module):
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
return 0.5 * x * (1.0 + torch.tanh(
|
| 41 |
+
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class T5LayerNorm(nn.Module):
|
| 45 |
+
def __init__(self, dim, eps=1e-6):
|
| 46 |
+
super(T5LayerNorm, self).__init__()
|
| 47 |
+
self.dim = dim
|
| 48 |
+
self.eps = eps
|
| 49 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
|
| 53 |
+
self.eps)
|
| 54 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 55 |
+
x = x.type_as(self.weight)
|
| 56 |
+
return self.weight * x
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class T5Attention(nn.Module):
|
| 60 |
+
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
| 61 |
+
assert dim_attn % num_heads == 0
|
| 62 |
+
super(T5Attention, self).__init__()
|
| 63 |
+
self.dim = dim
|
| 64 |
+
self.dim_attn = dim_attn
|
| 65 |
+
self.num_heads = num_heads
|
| 66 |
+
self.head_dim = dim_attn // num_heads
|
| 67 |
+
|
| 68 |
+
# layers
|
| 69 |
+
self.q = nn.Linear(dim, dim_attn, bias=False)
|
| 70 |
+
self.k = nn.Linear(dim, dim_attn, bias=False)
|
| 71 |
+
self.v = nn.Linear(dim, dim_attn, bias=False)
|
| 72 |
+
self.o = nn.Linear(dim_attn, dim, bias=False)
|
| 73 |
+
self.dropout = nn.Dropout(dropout)
|
| 74 |
+
|
| 75 |
+
def forward(self, x, context=None, mask=None, pos_bias=None):
|
| 76 |
+
"""
|
| 77 |
+
x: [B, L1, C].
|
| 78 |
+
context: [B, L2, C] or None.
|
| 79 |
+
mask: [B, L2] or [B, L1, L2] or None.
|
| 80 |
+
"""
|
| 81 |
+
# check inputs
|
| 82 |
+
context = x if context is None else context
|
| 83 |
+
b, n, c = x.size(0), self.num_heads, self.head_dim
|
| 84 |
+
|
| 85 |
+
# compute query, key, value
|
| 86 |
+
q = self.q(x).view(b, -1, n, c)
|
| 87 |
+
k = self.k(context).view(b, -1, n, c)
|
| 88 |
+
v = self.v(context).view(b, -1, n, c)
|
| 89 |
+
|
| 90 |
+
# attention bias
|
| 91 |
+
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
| 92 |
+
if pos_bias is not None:
|
| 93 |
+
attn_bias += pos_bias
|
| 94 |
+
if mask is not None:
|
| 95 |
+
assert mask.ndim in [2, 3]
|
| 96 |
+
mask = mask.view(b, 1, 1,
|
| 97 |
+
-1) if mask.ndim == 2 else mask.unsqueeze(1)
|
| 98 |
+
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
| 99 |
+
|
| 100 |
+
# compute attention (T5 does not use scaling)
|
| 101 |
+
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
|
| 102 |
+
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
| 103 |
+
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
| 104 |
+
|
| 105 |
+
# output
|
| 106 |
+
x = x.reshape(b, -1, n * c)
|
| 107 |
+
x = self.o(x)
|
| 108 |
+
x = self.dropout(x)
|
| 109 |
+
return x
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class T5FeedForward(nn.Module):
|
| 113 |
+
|
| 114 |
+
def __init__(self, dim, dim_ffn, dropout=0.1):
|
| 115 |
+
super(T5FeedForward, self).__init__()
|
| 116 |
+
self.dim = dim
|
| 117 |
+
self.dim_ffn = dim_ffn
|
| 118 |
+
|
| 119 |
+
# layers
|
| 120 |
+
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
| 121 |
+
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
| 122 |
+
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
| 123 |
+
self.dropout = nn.Dropout(dropout)
|
| 124 |
+
|
| 125 |
+
def forward(self, x):
|
| 126 |
+
x = self.fc1(x) * self.gate(x)
|
| 127 |
+
x = self.dropout(x)
|
| 128 |
+
x = self.fc2(x)
|
| 129 |
+
x = self.dropout(x)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class T5SelfAttention(nn.Module):
|
| 134 |
+
def __init__(self,
|
| 135 |
+
dim,
|
| 136 |
+
dim_attn,
|
| 137 |
+
dim_ffn,
|
| 138 |
+
num_heads,
|
| 139 |
+
num_buckets,
|
| 140 |
+
shared_pos=True,
|
| 141 |
+
dropout=0.1):
|
| 142 |
+
super(T5SelfAttention, self).__init__()
|
| 143 |
+
self.dim = dim
|
| 144 |
+
self.dim_attn = dim_attn
|
| 145 |
+
self.dim_ffn = dim_ffn
|
| 146 |
+
self.num_heads = num_heads
|
| 147 |
+
self.num_buckets = num_buckets
|
| 148 |
+
self.shared_pos = shared_pos
|
| 149 |
+
|
| 150 |
+
# layers
|
| 151 |
+
self.norm1 = T5LayerNorm(dim)
|
| 152 |
+
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 153 |
+
self.norm2 = T5LayerNorm(dim)
|
| 154 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
| 155 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
| 156 |
+
num_buckets, num_heads, bidirectional=True)
|
| 157 |
+
|
| 158 |
+
def forward(self, x, mask=None, pos_bias=None):
|
| 159 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(
|
| 160 |
+
x.size(1), x.size(1))
|
| 161 |
+
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
| 162 |
+
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
| 163 |
+
return x
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class T5CrossAttention(nn.Module):
|
| 167 |
+
def __init__(self,
|
| 168 |
+
dim,
|
| 169 |
+
dim_attn,
|
| 170 |
+
dim_ffn,
|
| 171 |
+
num_heads,
|
| 172 |
+
num_buckets,
|
| 173 |
+
shared_pos=True,
|
| 174 |
+
dropout=0.1):
|
| 175 |
+
super(T5CrossAttention, self).__init__()
|
| 176 |
+
self.dim = dim
|
| 177 |
+
self.dim_attn = dim_attn
|
| 178 |
+
self.dim_ffn = dim_ffn
|
| 179 |
+
self.num_heads = num_heads
|
| 180 |
+
self.num_buckets = num_buckets
|
| 181 |
+
self.shared_pos = shared_pos
|
| 182 |
+
|
| 183 |
+
# layers
|
| 184 |
+
self.norm1 = T5LayerNorm(dim)
|
| 185 |
+
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 186 |
+
self.norm2 = T5LayerNorm(dim)
|
| 187 |
+
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
| 188 |
+
self.norm3 = T5LayerNorm(dim)
|
| 189 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
| 190 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
| 191 |
+
num_buckets, num_heads, bidirectional=False)
|
| 192 |
+
|
| 193 |
+
def forward(self,
|
| 194 |
+
x,
|
| 195 |
+
mask=None,
|
| 196 |
+
encoder_states=None,
|
| 197 |
+
encoder_mask=None,
|
| 198 |
+
pos_bias=None):
|
| 199 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(
|
| 200 |
+
x.size(1), x.size(1))
|
| 201 |
+
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
|
| 202 |
+
x = fp16_clamp(x + self.cross_attn(
|
| 203 |
+
self.norm2(x), context=encoder_states, mask=encoder_mask))
|
| 204 |
+
x = fp16_clamp(x + self.ffn(self.norm3(x)))
|
| 205 |
+
return x
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class T5RelativeEmbedding(nn.Module):

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        if torch.device(type="meta") != device:
            rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
                torch.arange(lq, device=device).unsqueeze(1)
        else:
            rel_pos = torch.arange(lk).unsqueeze(0) - \
                torch.arange(lq).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets

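
# --- Illustrative sketch, not part of the original file ---
# T5-style relative attention replaces absolute positions with a learned bias
# per bucketed distance: small offsets get exact buckets, larger offsets share
# logarithmically spaced ones up to `max_dist`. The hypothetical helper below
# only demonstrates the resulting bias shape.
def _demo_relative_bias():
    emb = T5RelativeEmbedding(num_buckets=32, num_heads=8, bidirectional=True)
    bias = emb(5, 5)
    # One bias value per head and per (query, key) pair: [1, num_heads, Lq, Lk].
    assert bias.shape == (1, 8, 5, 5)
    return bias
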
class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(WanT5EncoderModel, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ):
        x = self.token_embedding(input_ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, attention_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return (x, )

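    # Illustrative usage sketch (hypothetical toy sizes, not a real config).
    # The encoder follows the HuggingFace convention of returning a tuple whose
    # first element is the final hidden states:
    #
    #     >>> enc = WanT5EncoderModel(vocab=256, dim=64, dim_attn=64, dim_ffn=128,
    #     ...                         num_heads=4, num_layers=2, num_buckets=32)
    #     >>> ids = torch.randint(0, 256, (1, 16))
    #     >>> (hidden,) = enc(input_ids=ids)
    #     >>> hidden.shape
    #     torch.Size([1, 16, 64])
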
    @classmethod
    def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16):
        def filter_kwargs(cls, kwargs):
            import inspect
            sig = inspect.signature(cls.__init__)
            valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
            return filtered_kwargs

        if low_cpu_mem_usage:
            try:
                import re

                from diffusers import __version__ as diffusers_version
                if diffusers_version >= "0.33.0":
                    from diffusers.models.model_loading_utils import \
                        load_model_dict_into_meta
                else:
                    from diffusers.models.modeling_utils import \
                        load_model_dict_into_meta
                from diffusers.utils import is_accelerate_available
                if is_accelerate_available():
                    import accelerate

                # Instantiate the model with empty weights on the meta device.
                with accelerate.init_empty_weights():
                    model = cls(**filter_kwargs(cls, additional_kwargs))

                param_device = "cpu"
                if pretrained_model_path.endswith(".safetensors"):
                    from safetensors.torch import load_file
                    state_dict = load_file(pretrained_model_path)
                else:
                    state_dict = torch.load(pretrained_model_path, map_location="cpu")

                if diffusers_version >= "0.33.0":
                    # Diffusers refactored `load_model_dict_into_meta` in version 0.33.0 in this commit:
                    # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
                    load_model_dict_into_meta(
                        model,
                        state_dict,
                        dtype=torch_dtype,
                        model_name_or_path=pretrained_model_path,
                    )
                else:
                    # move the params from the meta device to cpu
                    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                    if len(missing_keys) > 0:
                        raise ValueError(
                            f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
                            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                            " those weights or else make sure your checkpoint file is correct."
                        )

                    unexpected_keys = load_model_dict_into_meta(
                        model,
                        state_dict,
                        device=param_device,
                        dtype=torch_dtype,
                        model_name_or_path=pretrained_model_path,
                    )

                    if cls._keys_to_ignore_on_load_unexpected is not None:
                        for pat in cls._keys_to_ignore_on_load_unexpected:
                            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                    if len(unexpected_keys) > 0:
                        print(
                            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
                        )

                return model
            except Exception as e:
                print(
                    f"The low_cpu_mem_usage mode did not work because of {e}. Falling back to low_cpu_mem_usage=False."
                )

        model = cls(**filter_kwargs(cls, additional_kwargs))
        if pretrained_model_path.endswith(".safetensors"):
            from safetensors.torch import load_file
            state_dict = load_file(pretrained_model_path)
        else:
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m, u)

        model = model.to(torch_dtype)
        return model

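# --- Illustrative sketch with a hypothetical path and assumed architecture
# kwargs (take the real values from the checkpoint's config); shown only to
# orient readers on how `from_pretrained` above is meant to be called.
def _demo_from_pretrained():
    encoder = WanT5EncoderModel.from_pretrained(
        "path/to/umt5-xxl-enc.safetensors",  # hypothetical single-file checkpoint
        additional_kwargs=dict(
            vocab=256384, dim=4096, dim_attn=4096, dim_ffn=10240,
            num_heads=64, num_layers=24, num_buckets=32),
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16)
    return encoder.eval()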
videox_fun/models/wan_transformer3d.py
ADDED
@@ -0,0 +1,1399 @@
# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.

import glob
import json
import math
import os
import types
import warnings
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version, logging

from ..dist import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    usp_attn_forward, xFuserLongContextAttention)
from ..utils import cfg_skip
from .attention_utils import attention
from .cache_utils import TeaCache
from .wan_camera_adapter import SimpleAdapter


def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)

    # calculation
    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x


@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs

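
# --- Illustrative sketch, not part of the original file ---
# `rope_params` returns one complex rotation per (position, channel pair):
# entry [p, j] equals exp(i * p * theta**(-2j/dim)). `rope_apply` below
# multiplies complex-viewed feature pairs by this table, rotating each pair
# by a position-dependent angle.
def _demo_rope_params():
    freqs = rope_params(max_seq_len=8, dim=16)
    # One complex number per position and channel pair: [8, 16 // 2].
    assert freqs.shape == (8, 8) and freqs.is_complex()
    # Position 0 is the identity rotation (angle 0).
    assert torch.allclose(freqs[0], torch.ones(8, dtype=freqs.dtype))
    return freqs
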
# modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py
@amp.autocast(enabled=False)
def get_1d_rotary_pos_embed_riflex(
    pos: Union[np.ndarray, int],
    dim: int,
    theta: float = 10000.0,
    use_real=False,
    k: Optional[int] = None,
    L_test: Optional[int] = None,
    L_test_scale: Optional[int] = None,
):
    """
    RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials for the given dimension 'dim'
    and position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor
    contains complex values in complex64 data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return the real and imaginary parts separately. Otherwise, return complex numbers.
        k (`int`, *optional*, defaults to None): the index of the intrinsic frequency in RoPE
        L_test (`int`, *optional*, defaults to None): the number of frames for inference
        L_test_scale (`int`, *optional*, defaults to None): extra divisor applied to the intrinsic frequency
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    freqs = 1.0 / torch.pow(theta,
                            torch.arange(0, dim, 2).to(torch.float64).div(dim))

    # === RIFLEx modification start ===
    # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
    # Empirical observations show that a few videos may exhibit repetition in the tail frames.
    # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
    if k is not None:
        freqs[k-1] = 0.9 * 2 * torch.pi / L_test
    # === RIFLEx modification end ===
    if L_test_scale is not None:
        freqs[k-1] = freqs[k-1] / L_test_scale

    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis

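
# --- Illustrative sketch, not part of the original file; k and L_test below
# are made-up toy values. RIFLEx pins the k-th rotary frequency to
# 0.9 * 2*pi / L_test, so the slowest "intrinsic" channel completes at most
# 90% of one period over the L_test inference frames and the tail of an
# extrapolated video cannot wrap back to the phase of its head.
def _demo_riflex_freq():
    dim, k, L_test = 16, 3, 66
    freqs_cis = get_1d_rotary_pos_embed_riflex(L_test, dim, k=k, L_test=L_test)
    assert freqs_cis.shape == (L_test, dim // 2)
    # The per-frame phase step of channel k-1 equals 0.9 * 2*pi / L_test.
    step = torch.angle(freqs_cis[1, k - 1]).item()
    assert math.isclose(step, 0.9 * 2 * math.pi / L_test, rel_tol=1e-6)
    return freqs_cis
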
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    tw = tgt_width
    th = tgt_height
    h, w = src
    r = h / w
    if r > (th / tw):
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

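
# --- Illustrative numeric check, not part of the original file ---
# Mapping a 480x720 (h, w) source onto a 512x512 target: the aspect ratio is
# kept, the width is matched (341 x 512 after resizing), and the crop window
# is centered vertically on the target canvas.
def _demo_resize_crop():
    (top, left), (bottom, right) = get_resize_crop_region_for_grid(
        (480, 720), tgt_width=512, tgt_height=512)
    assert (top, left, bottom, right) == (86, 0, 427, 512)
    return (top, left), (bottom, right)
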
@amp.autocast(enabled=False)
@torch.compiler.disable()
def rope_apply(x, grid_sizes, freqs, frame_split_indices=None, ground_frame_indices=None):
    n, c = x.size(2), x.size(3) // 2

    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
            seq_len, n, -1, 2))

        # Handle temporal freqs with a split for paired data and ground frames
        if frame_split_indices is not None and i < len(frame_split_indices):
            # Split temporal positions: src [0, f_src-1], ground [0], tgt [0, f_tgt-1]
            f_src = frame_split_indices[i]

            # Check if we have ground frames
            if ground_frame_indices is not None and i < len(ground_frame_indices):
                ground_start, ground_end = ground_frame_indices[i]
                f_ground = ground_end - ground_start
                f_tgt = f - f_src - f_ground

                # Generate independent temporal freqs
                # Src: positions [1..f_src]
                freqs_src_t = freqs[0][1:f_src + 1].view(f_src, 1, 1, -1).expand(f_src, h, w, -1)

                # Ground: force all frames to use position 0
                freqs_ground_t = freqs[0][:1].view(1, 1, 1, -1).repeat(f_ground, h, w, 1)

                # Tgt: positions [1..f_tgt]
                freqs_tgt_t = freqs[0][1:f_tgt + 1].view(f_tgt, 1, 1, -1).expand(f_tgt, h, w, -1)

                freqs_temporal = torch.cat([freqs_src_t, freqs_ground_t, freqs_tgt_t], dim=0)
            else:
                # No ground frames: regular paired data
                f_tgt = f - f_src

                # Generate independent temporal freqs for src and tgt
                freqs_src_t = freqs[0][:f_src].view(f_src, 1, 1, -1).expand(f_src, h, w, -1)
                freqs_tgt_t = freqs[0][:f_tgt].view(f_tgt, 1, 1, -1).expand(f_tgt, h, w, -1)
                freqs_temporal = torch.cat([freqs_src_t, freqs_tgt_t], dim=0)
        else:
            # Default: continuous temporal positions [0, f-1]
            freqs_temporal = freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1)

        # Combine temporal + spatial freqs
        freqs_i = torch.cat([
            freqs_temporal,
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, seq_len:]])
        # append to collection
        output.append(x_i)
    return torch.stack(output).to(x.dtype)


def rope_apply_qk(q, k, grid_sizes, freqs, frame_split_indices=None, ground_frame_indices=None):
    q = rope_apply(q, grid_sizes, freqs, frame_split_indices, ground_frame_indices)
    k = rope_apply(k, grid_sizes, freqs, frame_split_indices, ground_frame_indices)
    return q, k

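
# --- Illustrative sketch, not part of the original file; toy sizes below are
# made up. This reproduces the temporal index layout `rope_apply` builds for
# paired data with ground frames: source frames count 1..f_src, ground frames
# are all pinned to index 0, and target frames restart at 1..f_tgt, so the
# n-th source and target frames share the same temporal phase while the
# ground frames act as a shared "time zero" anchor.
def _demo_temporal_indices():
    f, f_src, (ground_start, ground_end) = 9, 4, (4, 5)
    f_ground = ground_end - ground_start
    f_tgt = f - f_src - f_ground
    idx = (list(range(1, f_src + 1))     # src:    1, 2, 3, 4
           + [0] * f_ground              # ground: 0
           + list(range(1, f_tgt + 1)))  # tgt:    1, 2, 3, 4
    assert idx == [1, 2, 3, 4, 0, 1, 2, 3, 4]
    return idx
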
class WanRMSNorm(nn.Module):

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        return self._norm(x) * self.weight

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(x.dtype)


class WanLayerNorm(nn.LayerNorm):

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        return super().forward(x)

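
# --- Illustrative sketch, not part of the original file ---
# WanRMSNorm rescales each token vector by its root mean square only, with no
# mean subtraction (unlike LayerNorm): y = x / sqrt(mean(x^2) + eps) * weight.
def _demo_rmsnorm():
    norm = WanRMSNorm(dim=4, eps=0.0)
    x = torch.tensor([[[3.0, -3.0, 3.0, -3.0]]])  # rms(x) = 3
    y = norm(x)
    assert torch.allclose(y, x / 3.0)  # weight is initialized to ones
    return y
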
class WanSelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0, frame_split_indices=None, ground_frame_indices=None):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
            frame_split_indices(List[int], optional): Split indices for paired data temporal RoPE
            ground_frame_indices(List[Tuple[int, int]], optional): Ground frame positions for special temporal RoPE
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d)
            k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d)
            v = self.v(x.to(dtype)).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        q, k = rope_apply_qk(q, k, grid_sizes, freqs, frame_split_indices, ground_frame_indices)

        x = attention(
            q.to(dtype),
            k.to(dtype),
            v=v.to(dtype),
            k_lens=seq_lens,
            window_size=self.window_size)
        x = x.to(dtype)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x

class WanT2VCrossAttention(WanSelfAttention):

    def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
        k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
        v = self.v(context.to(dtype)).view(b, -1, n, d)

        # compute attention
        x = attention(
            q.to(dtype),
            k.to(dtype),
            v.to(dtype),
            k_lens=context_lens
        )
        x = x.to(dtype)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x

class WanI2VCrossAttention(WanSelfAttention):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        super().__init__(dim, num_heads, window_size, qk_norm, eps)

        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        # self.alpha = nn.Parameter(torch.zeros((1, )))
        self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        context_img = context[:, :257]
        context = context[:, 257:]
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
        k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
        v = self.v(context.to(dtype)).view(b, -1, n, d)
        k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d)
        v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d)

        img_x = attention(
            q.to(dtype),
            k_img.to(dtype),
            v_img.to(dtype),
            k_lens=None
        )
        img_x = img_x.to(dtype)
        # compute attention
        x = attention(
            q.to(dtype),
            k.to(dtype),
            v.to(dtype),
            k_lens=context_lens
        )
        x = x.to(dtype)

        # output
        x = x.flatten(2)
        img_x = img_x.flatten(2)
        x = x + img_x
        x = self.o(x)
        return x

class WanCrossAttention(WanSelfAttention):
    def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        b, n, d = x.size(0), self.num_heads, self.head_dim
        # compute query, key, value
        q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
        k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
        v = self.v(context.to(dtype)).view(b, -1, n, d)
        # compute attention
        x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens)
        # output
        x = x.flatten(2)
        x = self.o(x.to(dtype))
        return x


WAN_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
    'cross_attn': WanCrossAttention,
}

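
# --- Illustrative sketch, not part of the original file ---
# The registry above selects the conditioning path per model variant:
# 't2v_cross_attn' attends to text tokens only; 'i2v_cross_attn' additionally
# routes the first 257 context tokens (CLIP image features) through separate
# k_img/v_img projections and sums both attention outputs; 'cross_attn' is the
# plain variant. The hypothetical helper mirrors the default chosen in
# WanTransformer3DModel.__init__ below.
def _demo_pick_cross_attn(model_type="i2v"):
    key = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
    return WAN_CROSSATTENTION_CLASSES[key](dim=64, num_heads=4)
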
class WanAttentionBlock(nn.Module):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        self.norm3 = WanLayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
                                                                      num_heads,
                                                                      (-1, -1),
                                                                      qk_norm,
                                                                      eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
        dtype=torch.bfloat16,
        t=0,
        frame_split_indices=None,
        ground_frame_indices=None,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
            frame_split_indices(List[int], optional): Split indices for paired data temporal RoPE
            ground_frame_indices(List[Tuple[int, int]], optional): Ground frame positions for special temporal RoPE
        """
        if e.dim() > 3:
            e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
            e = [e.squeeze(2) for e in e]
        else:
            e = (self.modulation + e).chunk(6, dim=1)

        # self-attention
        temp_x = self.norm1(x) * (1 + e[1]) + e[0]
        temp_x = temp_x.to(dtype)

        y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype, t=t, frame_split_indices=frame_split_indices, ground_frame_indices=ground_frame_indices)
        x = x + y * e[2]

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            # cross-attention
            x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype, t=t)

            # ffn function
            temp_x = self.norm2(x) * (1 + e[4]) + e[3]
            temp_x = temp_x.to(dtype)

            y = self.ffn(temp_x)
            x = x + y * e[5]
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x

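
# --- Illustrative sketch, not part of the original file ---
# The six modulation vectors chunked from `e` in WanAttentionBlock.forward
# play the usual adaLN-Zero roles: e[0]/e[1] shift/scale the pre-attention
# norm, e[2] gates the attention residual, e[3]/e[4] shift/scale the pre-FFN
# norm, and e[5] gates the FFN residual:
#     attn_in = norm1(x) * (1 + e[1]) + e[0];  x = x + attn(attn_in) * e[2]
#     ffn_in  = norm2(x) * (1 + e[4]) + e[3];  x = x + ffn(ffn_in) * e[5]
def _demo_modulation_chunks():
    dim = 8
    modulation = torch.randn(1, 6, dim) / dim ** 0.5
    e = (modulation + torch.zeros(1, 6, dim)).chunk(6, dim=1)
    assert len(e) == 6 and e[0].shape == (1, 1, dim)
    return e
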
class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]
        """
        if e.dim() > 2:
            e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
            e = [e.squeeze(2) for e in e]
        else:
            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)

        x = self.head(self.norm(x) * (1 + e[1]) + e[0])
        return x


class MLPProj(torch.nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim))

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class WanTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    # ignore_for_config = [
    #     'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    # ]
    # _no_split_modules = ['WanAttentionBlock']
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        model_type='t2v',
        patch_size=(1, 2, 2),
        text_len=512,
        in_dim=16,
        dim=2048,
        ffn_dim=8192,
        freq_dim=256,
        text_dim=4096,
        out_dim=16,
        num_heads=16,
        num_layers=32,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=True,
        eps=1e-6,
        in_channels=16,
        hidden_size=2048,
        add_control_adapter=False,
        in_dim_control_adapter=24,
        downscale_factor_control_adapter=8,
        add_ref_conv=False,
        in_dim_ref_conv=16,
        cross_attn_type=None,
    ):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in the feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """

        super().__init__()

        # assert model_type in ['t2v', 'i2v', 'ti2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        if cross_attn_type is None:
            cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps)
            for _ in range(num_layers)
        ])
        for layer_idx, block in enumerate(self.blocks):
            block.self_attn.layer_idx = layer_idx
            block.self_attn.num_layers = self.num_layers

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer, otherwise the dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.d = d
        self.dim = dim
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6))
            ],
            dim=1
        )

        if model_type == 'i2v':
            self.img_emb = MLPProj(1280, dim)

        if add_control_adapter:
            self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], downscale_factor=downscale_factor_control_adapter)
        else:
            self.control_adapter = None

        if add_ref_conv:
            self.ref_conv = nn.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
        else:
            self.ref_conv = None

        self.teacache = None
        self.cfg_skip_ratio = None
        self.current_steps = 0
        self.num_inference_steps = None
        self.gradient_checkpointing = False
        self.sp_world_size = 1
        self.sp_world_rank = 0
        self.init_weights()

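    # Illustrative sketch (hypothetical toy configuration, not a released
    # checkpoint). `dim` must be divisible by `num_heads` and the head dim must
    # be even so the 3-way RoPE split above works out:
    #
    #     >>> model = WanTransformer3DModel(
    #     ...     model_type='t2v', patch_size=(1, 2, 2), text_len=64,
    #     ...     in_dim=16, dim=128, ffn_dim=256, freq_dim=32, text_dim=32,
    #     ...     out_dim=16, num_heads=8, num_layers=2)
    #     >>> model.d  # head dim 16: 8 temporal + 4 + 4 spatial rotary channels
    #     16
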
    def _set_gradient_checkpointing(self, *args, **kwargs):
        if "value" in kwargs:
            self.gradient_checkpointing = kwargs["value"]
        elif "enable" in kwargs:
            self.gradient_checkpointing = kwargs["enable"]
        else:
            raise ValueError("Invalid gradient checkpointing setting")

    def enable_teacache(
        self,
        coefficients,
        num_steps: int,
        rel_l1_thresh: float,
        num_skip_start_steps: int = 0,
        offload: bool = True,
    ):
        self.teacache = TeaCache(
            coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
        )

    def share_teacache(
        self,
        transformer=None,
    ):
        self.teacache = transformer.teacache

    def disable_teacache(self):
        self.teacache = None

    def enable_cfg_skip(self, cfg_skip_ratio, num_steps):
        if cfg_skip_ratio != 0:
            self.cfg_skip_ratio = cfg_skip_ratio
            self.current_steps = 0
            self.num_inference_steps = num_steps
        else:
            self.cfg_skip_ratio = None
            self.current_steps = 0
            self.num_inference_steps = None

    def share_cfg_skip(
        self,
        transformer=None,
    ):
        self.cfg_skip_ratio = transformer.cfg_skip_ratio
        self.current_steps = transformer.current_steps
        self.num_inference_steps = transformer.num_inference_steps

    def disable_cfg_skip(self):
        self.cfg_skip_ratio = None
        self.current_steps = 0
        self.num_inference_steps = None

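    # Illustrative usage sketch (placeholder numbers, not tuned values from the
    # repo). TeaCache skips the transformer blocks on steps where the modulated
    # input has barely moved; `coefficients` define the polynomial that
    # rescales the raw relative L1 distance accumulated against
    # `rel_l1_thresh`:
    #
    #     >>> model.enable_teacache(coefficients=[1.0, 0.0, 0.0, 0.0, 0.0],
    #     ...                       num_steps=50, rel_l1_thresh=0.08,
    #     ...                       num_skip_start_steps=5, offload=True)
    #     >>> model.enable_cfg_skip(cfg_skip_ratio=0.25, num_steps=50)
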
    def enable_riflex(
        self,
        k=6,
        L_test=66,
        L_test_scale=4.886,
    ):
        device = self.freqs.device
        self.freqs = torch.cat(
            [
                get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale),
                rope_params(1024, 2 * (self.d // 6)),
                rope_params(1024, 2 * (self.d // 6))
            ],
            dim=1
        ).to(device)

    def disable_riflex(self):
        device = self.freqs.device
        self.freqs = torch.cat(
            [
                rope_params(1024, self.d - 4 * (self.d // 6)),
                rope_params(1024, 2 * (self.d // 6)),
                rope_params(1024, 2 * (self.d // 6))
            ],
            dim=1
        ).to(device)

    def enable_multi_gpus_inference(self):
        self.sp_world_size = get_sequence_parallel_world_size()
        self.sp_world_rank = get_sequence_parallel_rank()
        self.all_gather = get_sp_group().all_gather

        # For the normal model.
        for block in self.blocks:
            block.self_attn.forward = types.MethodType(
                usp_attn_forward, block.self_attn)

        # For the VACE model.
        if hasattr(self, 'vace_blocks'):
            for block in self.vace_blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)

@cfg_skip()
|
| 819 |
+
def forward(
|
| 820 |
+
self,
|
| 821 |
+
x,
|
| 822 |
+
t,
|
| 823 |
+
context,
|
| 824 |
+
seq_len,
|
| 825 |
+
clip_fea=None,
|
| 826 |
+
y=None,
|
| 827 |
+
y_camera=None,
|
| 828 |
+
full_ref=None,
|
| 829 |
+
subject_ref=None,
|
| 830 |
+
cond_flag=True,
|
| 831 |
+
frame_split_indices=None,
|
| 832 |
+
ground_frame_indices=None,
|
| 833 |
+
):
|
| 834 |
+
r"""
|
| 835 |
+
Forward pass through the diffusion model
|
| 836 |
+
|
| 837 |
+
Args:
|
| 838 |
+
x (List[Tensor]):
|
| 839 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
| 840 |
+
t (Tensor):
|
| 841 |
+
Diffusion timesteps tensor of shape [B]
|
| 842 |
+
context (List[Tensor]):
|
| 843 |
+
List of text embeddings each with shape [L, C]
|
| 844 |
+
seq_len (`int`):
|
| 845 |
+
Maximum sequence length for positional encoding
|
| 846 |
+
clip_fea (Tensor, *optional*):
|
| 847 |
+
CLIP image features for image-to-video mode
|
| 848 |
+
y (List[Tensor], *optional*):
|
| 849 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
| 850 |
+
cond_flag (`bool`, *optional*, defaults to True):
|
| 851 |
+
Flag to indicate whether to forward the condition input
|
| 852 |
+
|
| 853 |
+
Returns:
|
| 854 |
+
List[Tensor]:
|
| 855 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
| 856 |
+
"""
|
| 857 |
+
# Wan2.2 don't need a clip.
|
| 858 |
+
# if self.model_type == 'i2v':
|
| 859 |
+
# assert clip_fea is not None and y is not None
|
| 860 |
+
# params
|
| 861 |
+
device = self.patch_embedding.weight.device
|
| 862 |
+
dtype = x.dtype
|
| 863 |
+
if self.freqs.device != device and torch.device(type="meta") != device:
|
| 864 |
+
self.freqs = self.freqs.to(device)
|
| 865 |
+
|
| 866 |
+
if y is not None:
|
| 867 |
+
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
| 868 |
+
|
| 869 |
+
# embeddings
|
| 870 |
+
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
| 871 |
+
# add control adapter
|
| 872 |
+
if self.control_adapter is not None and y_camera is not None:
|
| 873 |
+
y_camera = self.control_adapter(y_camera)
|
| 874 |
+
x = [u + v for u, v in zip(x, y_camera)]
|
| 875 |
+
|
| 876 |
+
grid_sizes = torch.stack(
|
| 877 |
+
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
| 878 |
+
|
| 879 |
+
x = [u.flatten(2).transpose(1, 2) for u in x]
|
| 880 |
+
if self.ref_conv is not None and full_ref is not None:
|
| 881 |
+
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
| 882 |
+
grid_sizes = torch.stack([torch.tensor([u[0] + 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
|
| 883 |
+
seq_len += full_ref.size(1)
|
| 884 |
+
x = [torch.concat([_full_ref.unsqueeze(0), u], dim=1) for _full_ref, u in zip(full_ref, x)]
|
| 885 |
+
if t.dim() != 1 and t.size(1) < seq_len:
|
| 886 |
+
pad_size = seq_len - t.size(1)
|
| 887 |
+
last_elements = t[:, -1].unsqueeze(1)
|
| 888 |
+
padding = last_elements.repeat(1, pad_size)
|
| 889 |
+
t = torch.cat([padding, t], dim=1)
|
| 890 |
+
|
| 891 |
+
if subject_ref is not None:
|
| 892 |
+
subject_ref_frames = subject_ref.size(2)
|
| 893 |
+
subject_ref = self.patch_embedding(subject_ref).flatten(2).transpose(1, 2)
|
| 894 |
+
grid_sizes = torch.stack([torch.tensor([u[0] + subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
|
| 895 |
+
seq_len += subject_ref.size(1)
|
| 896 |
+
x = [torch.concat([u, _subject_ref.unsqueeze(0)], dim=1) for _subject_ref, u in zip(subject_ref, x)]
|
| 897 |
+
if t.dim() != 1 and t.size(1) < seq_len:
|
| 898 |
+
pad_size = seq_len - t.size(1)
|
| 899 |
+
last_elements = t[:, -1].unsqueeze(1)
|
| 900 |
+
padding = last_elements.repeat(1, pad_size)
|
| 901 |
+
t = torch.cat([t, padding], dim=1)
|
| 902 |
+
|
| 903 |
+
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
| 904 |
+
if self.sp_world_size > 1:
|
| 905 |
+
seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
|
| 906 |
+
assert seq_lens.max() <= seq_len
|
| 907 |
+
x = torch.cat([
|
| 908 |
+
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
| 909 |
+
dim=1) for u in x
|
| 910 |
+
])
|
| 911 |
+
|
| 912 |
+
# time embeddings
|
| 913 |
+
with amp.autocast(dtype=torch.float32):
|
| 914 |
+
if t.dim() != 1:
|
| 915 |
+
if t.size(1) < seq_len:
|
| 916 |
+
pad_size = seq_len - t.size(1)
|
| 917 |
+
last_elements = t[:, -1].unsqueeze(1)
|
| 918 |
+
padding = last_elements.repeat(1, pad_size)
|
| 919 |
+
t = torch.cat([t, padding], dim=1)
|
| 920 |
+
bt = t.size(0)
|
| 921 |
+
ft = t.flatten()
|
| 922 |
+
e = self.time_embedding(
|
| 923 |
+
sinusoidal_embedding_1d(self.freq_dim,
|
| 924 |
+
ft).unflatten(0, (bt, seq_len)).float())
|
| 925 |
+
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
| 926 |
+
else:
|
| 927 |
+
e = self.time_embedding(
|
| 928 |
+
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
| 929 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
| 930 |
+
|
| 931 |
+
# assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 932 |
+
# e0 = e0.to(dtype)
|
| 933 |
+
# e = e.to(dtype)
|
| 934 |
+
|
| 935 |
+
# context
|
| 936 |
+
context_lens = None
|
| 937 |
+
context = self.text_embedding(
|
| 938 |
+
torch.stack([
|
| 939 |
+
torch.cat(
|
| 940 |
+
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
| 941 |
+
for u in context
|
| 942 |
+
]))
|
| 943 |
+
|
| 944 |
+
if clip_fea is not None:
|
| 945 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 946 |
+
context = torch.concat([context_clip, context], dim=1)
|
| 947 |
+
|
| 948 |
+
# Context Parallel
|
| 949 |
+
if self.sp_world_size > 1:
|
| 950 |
+
x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 951 |
+
if t.dim() != 1:
|
| 952 |
+
e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
|
| 953 |
+
            e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]

        # TeaCache
        if self.teacache is not None:
            if cond_flag:
                if t.dim() != 1:
                    modulated_inp = e0[:, -1, :]
                else:
                    modulated_inp = e0
                skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
                if skip_flag:
                    self.should_calc = True
                    self.teacache.accumulated_rel_l1_distance = 0
                else:
                    if cond_flag:
                        rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
                        self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
                    if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
                        self.should_calc = False
                    else:
                        self.should_calc = True
                        self.teacache.accumulated_rel_l1_distance = 0
                self.teacache.previous_modulated_input = modulated_inp
                self.teacache.should_calc = self.should_calc
            else:
                self.should_calc = self.teacache.should_calc

        # TeaCache
        if self.teacache is not None:
            if not self.should_calc:
                previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
                x = x + previous_residual.to(x.device)[-x.size()[0]:,]
            else:
                ori_x = x.clone().cpu() if self.teacache.offload else x.clone()

                for block in self.blocks:
                    if torch.is_grad_enabled() and self.gradient_checkpointing:

                        def create_custom_forward(module):
                            def custom_forward(*inputs):
                                return module(*inputs)

                            return custom_forward
                        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x,
                            e0,
                            seq_lens,
                            grid_sizes,
                            self.freqs,
                            context,
                            context_lens,
                            dtype,
                            t,
                            frame_split_indices,
                            ground_frame_indices,
                            **ckpt_kwargs,
                        )
                    else:
                        # arguments
                        kwargs = dict(
                            e=e0,
                            seq_lens=seq_lens,
                            grid_sizes=grid_sizes,
                            freqs=self.freqs,
                            context=context,
                            context_lens=context_lens,
                            dtype=dtype,
                            t=t,
                            frame_split_indices=frame_split_indices,
                            ground_frame_indices=ground_frame_indices,
                        )
                        x = block(x, **kwargs)

                if cond_flag:
                    self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
                else:
                    self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
        else:
            for block in self.blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        e0,
                        seq_lens,
                        grid_sizes,
                        self.freqs,
                        context,
                        context_lens,
                        dtype,
                        t,
                        frame_split_indices,
                        ground_frame_indices,
                        **ckpt_kwargs,
                    )
                else:
                    # arguments
                    kwargs = dict(
                        e=e0,
                        seq_lens=seq_lens,
                        grid_sizes=grid_sizes,
                        freqs=self.freqs,
                        context=context,
                        context_lens=context_lens,
                        dtype=dtype,
                        t=t,
                        frame_split_indices=frame_split_indices,
                        ground_frame_indices=ground_frame_indices,
                    )
                    x = block(x, **kwargs)

        # head
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
        else:
            x = self.head(x, e)

        if self.sp_world_size > 1:
            x = self.all_gather(x, dim=1)

        if self.ref_conv is not None and full_ref is not None:
            full_ref_length = full_ref.size(1)
            x = x[:, full_ref_length:]
            grid_sizes = torch.stack([torch.tensor([u[0] - 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)

        if subject_ref is not None:
            subject_ref_length = subject_ref.size(1)
            x = x[:, :-subject_ref_length]
            grid_sizes = torch.stack([torch.tensor([u[0] - subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        x = torch.stack(x)
        if self.teacache is not None and cond_flag:
            self.teacache.cnt += 1
            if self.teacache.cnt == self.teacache.num_steps:
                self.teacache.reset()
        return x

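    # The two TeaCache blocks above implement step skipping: the first block
    # accumulates a rescaled relative-L1 distance between the current and the
    # previous modulated timestep embedding and decides `should_calc`; the
    # second block either reuses the cached residual (skip) or runs all blocks
    # and records a fresh residual. A minimal sketch of the decision rule,
    # with hypothetical standalone names rather than the TeaCache class used above:
    #
    #     def teacache_should_skip(prev_emb, cur_emb, acc, thresh):
    #         rel_l1 = ((cur_emb - prev_emb).abs().mean() /
    #                   prev_emb.abs().mean()).item()
    #         acc += rel_l1  # the real code first rescales rel_l1 with a fitted polynomial
    #         if acc < thresh:
    #             return True, acc   # skip: reuse the previous residual
    #         return False, 0.0      # recompute and reset the accumulator
    #
    # The real implementation additionally never skips the first
    # `num_skip_start_steps` steps.
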
    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

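    # Shape walk-through for `unpatchify` (illustrative numbers, not taken
    # from the source): with patch_size = (1, 2, 2), out_dim = 16 and a patch
    # grid (F, H', W') = (21, 30, 52), each row of `u` holds
    # 16 * 1 * 2 * 2 = 64 values and there are L = 21 * 30 * 52 = 32760 rows;
    # the einsum 'fhwpqrc->cfphqwr' interleaves the per-patch axes back
    # between the grid axes, and the final reshape yields
    # [16, 21 * 1, 30 * 2, 52 * 2] = [16, 21, 60, 104].
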
    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)

    @classmethod
    def from_pretrained(
        cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
        low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
    ):
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)

        from diffusers.utils import WEIGHTS_NAME
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        model_file_safetensors = model_file.replace(".bin", ".safetensors")

        if "dict_mapping" in transformer_additional_kwargs.keys():
            for key in transformer_additional_kwargs["dict_mapping"]:
                transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]

        if low_cpu_mem_usage:
            try:
                import re

                from diffusers import __version__ as diffusers_version
                if diffusers_version >= "0.33.0":
                    from diffusers.models.model_loading_utils import \
                        load_model_dict_into_meta
                else:
                    from diffusers.models.modeling_utils import \
                        load_model_dict_into_meta
                from diffusers.utils import is_accelerate_available
                if is_accelerate_available():
                    import accelerate

                # Instantiate model with empty weights
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **transformer_additional_kwargs)

                param_device = "cpu"
                if os.path.exists(model_file):
                    state_dict = torch.load(model_file, map_location="cpu")
                elif os.path.exists(model_file_safetensors):
                    from safetensors.torch import load_file, safe_open
                    state_dict = load_file(model_file_safetensors)
                else:
                    from safetensors.torch import load_file, safe_open
                    model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
                    state_dict = {}
                    print(model_files_safetensors)
                    for _model_file_safetensors in model_files_safetensors:
                        _state_dict = load_file(_model_file_safetensors)
                        for key in _state_dict:
                            state_dict[key] = _state_dict[key]

                if diffusers_version >= "0.33.0":
                    # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
                    # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
                    load_model_dict_into_meta(
                        model,
                        state_dict,
                        dtype=torch_dtype,
                        model_name_or_path=pretrained_model_path,
                    )
                else:
                    model._convert_deprecated_attention_blocks(state_dict)
                    # move the params from meta device to cpu
                    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                    if len(missing_keys) > 0:
                        raise ValueError(
                            f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
                            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                            " those weights or else make sure your checkpoint file is correct."
                        )

                    unexpected_keys = load_model_dict_into_meta(
                        model,
                        state_dict,
                        device=param_device,
                        dtype=torch_dtype,
                        model_name_or_path=pretrained_model_path,
                    )

                    if cls._keys_to_ignore_on_load_unexpected is not None:
                        for pat in cls._keys_to_ignore_on_load_unexpected:
                            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                    if len(unexpected_keys) > 0:
                        print(
                            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
                        )

                return model
            except Exception as e:
                print(
                    f"The low_cpu_mem_usage mode does not work because {e}. Using low_cpu_mem_usage=False instead."
                )

        model = cls.from_config(config, **transformer_additional_kwargs)
        if os.path.exists(model_file):
            state_dict = torch.load(model_file, map_location="cpu")
        elif os.path.exists(model_file_safetensors):
            from safetensors.torch import load_file, safe_open
            state_dict = load_file(model_file_safetensors)
        else:
            from safetensors.torch import load_file, safe_open
            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
            state_dict = {}
            for _model_file_safetensors in model_files_safetensors:
                _state_dict = load_file(_model_file_safetensors)
                for key in _state_dict:
                    state_dict[key] = _state_dict[key]

        if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
            model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :]
            model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
            state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']

        tmp_state_dict = {}
        for key in state_dict:
            if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
                tmp_state_dict[key] = state_dict[key]
            else:
                print(key, "Size doesn't match, skip")

        state_dict = tmp_state_dict

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m)

        params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
        print(f"### All Parameters: {sum(params) / 1e6} M")

        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
        print(f"### attn1 Parameters: {sum(params) / 1e6} M")

        model = model.to(torch_dtype)
        return model

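# A minimal usage sketch for the loader above (the checkpoint path and
# subfolder are placeholders, not values shipped with this repo):
#
#     model = WanTransformer3DModel.from_pretrained(
#         "path/to/wan_checkpoint",
#         subfolder="transformer",
#         low_cpu_mem_usage=True,       # meta-device init via accelerate
#         torch_dtype=torch.bfloat16,
#     )
#
# With low_cpu_mem_usage=True the weights are materialized straight from the
# checkpoint through `load_model_dict_into_meta`; on any failure the method
# falls back to `from_config` plus `load_state_dict(strict=False)`, which also
# tolerates a `patch_embedding` whose input-channel count differs from the
# checkpoint (extra channels are zero-filled).
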
class Wan2_2Transformer3DModel(WanTransformer3DModel):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    # ignore_for_config = [
    #     'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    # ]
    # _no_split_modules = ['WanAttentionBlock']
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        model_type='t2v',
        patch_size=(1, 2, 2),
        text_len=512,
        in_dim=16,
        dim=2048,
        ffn_dim=8192,
        freq_dim=256,
        text_dim=4096,
        out_dim=16,
        num_heads=16,
        num_layers=32,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=True,
        eps=1e-6,
        in_channels=16,
        hidden_size=2048,
        add_control_adapter=False,
        in_dim_control_adapter=24,
        downscale_factor_control_adapter=8,
        add_ref_conv=False,
        in_dim_ref_conv=16,
    ):
        r"""
        Initialize the diffusion model backbone.
        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """
        super().__init__(
            model_type=model_type,
            patch_size=patch_size,
            text_len=text_len,
            in_dim=in_dim,
            dim=dim,
            ffn_dim=ffn_dim,
            freq_dim=freq_dim,
            text_dim=text_dim,
            out_dim=out_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            window_size=window_size,
            qk_norm=qk_norm,
            cross_attn_norm=cross_attn_norm,
            eps=eps,
            in_channels=in_channels,
            hidden_size=hidden_size,
            add_control_adapter=add_control_adapter,
            in_dim_control_adapter=in_dim_control_adapter,
            downscale_factor_control_adapter=downscale_factor_control_adapter,
            add_ref_conv=add_ref_conv,
            in_dim_ref_conv=in_dim_ref_conv,
            cross_attn_type="cross_attn"
        )

        if hasattr(self, "img_emb"):
            del self.img_emb
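
# Note that Wan2_2Transformer3DModel only overrides the constructor: it fixes
# cross_attn_type="cross_attn" and removes the i2v image-embedding branch
# (`img_emb`), so the forward pass is inherited unchanged from
# WanTransformer3DModel. A minimal instantiation sketch (the arguments shown
# are simply the defaults above, not a tested configuration):
#
#     model = Wan2_2Transformer3DModel(model_type='t2v', dim=2048, num_layers=32)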

videox_fun/models/wan_transformer3d_s2v.py
ADDED
@@ -0,0 +1,887 @@

# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/model_s2v.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.

import math
import types
from copy import deepcopy
from typing import Any, Dict

import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from diffusers.utils import is_torch_version
from einops import rearrange

from ..dist import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    usp_attn_s2v_forward)
from .attention_utils import attention
from .wan_audio_injector import (AudioInjector_WAN, CausalAudioEncoder,
                                 FramePackMotioner, MotionerTransformers,
                                 rope_precompute)
from .wan_transformer3d import (Wan2_2Transformer3DModel, WanAttentionBlock,
                                WanLayerNorm, WanSelfAttention,
                                sinusoidal_embedding_1d)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def torch_dfs(model: nn.Module, parent_name='root'):
    module_names, modules = [], []
    current_name = parent_name if parent_name else 'root'
    module_names.append(current_name)
    modules.append(model)

    for name, child in model.named_children():
        if parent_name:
            child_name = f'{parent_name}.{name}'
        else:
            child_name = name
        child_modules, child_names = torch_dfs(child, child_name)
        module_names += child_names
        modules += child_modules
    return modules, module_names


@amp.autocast(enabled=False)
@torch.compiler.disable()
def s2v_rope_apply(x, grid_sizes, freqs, start=None):
    n, c = x.size(2), x.size(3) // 2
    # loop over samples
    output = []
    for i, _ in enumerate(x):
        s = x.size(1)
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
        freqs_i = freqs[i, :s]
        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])
        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
    q = s2v_rope_apply(q, grid_sizes, freqs)
    k = s2v_rope_apply(k, grid_sizes, freqs)
    return q, k

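# Unlike the base model, `s2v_rope_apply` expects `freqs` already expanded per
# token as a complex tensor of shape [B, L, num_heads, head_dim // 2]
# (typically produced by `rope_precompute`). A toy check, as a sketch with
# illustrative dimensions only:
#
#     # x: [B, L, heads, head_dim]; unit-magnitude, zero-phase freqs -> identity
#     x = torch.randn(1, 8, 16, 128)
#     freqs = torch.polar(torch.ones(1, 8, 16, 64), torch.zeros(1, 8, 16, 64))
#     out = s2v_rope_apply(x, grid_sizes=None, freqs=freqs)  # [1, 8, 16, 128]
#
# With zero phase angles the rotation is the identity, so `out` matches `x` up
# to the float64 round trip (`grid_sizes` is accepted but unused here).
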
class WanS2VSelfAttention(WanSelfAttention):

    def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0):
        """
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)

        x = attention(
            q.to(dtype),
            k.to(dtype),
            v=v.to(dtype),
            k_lens=seq_lens,
            window_size=self.window_size)
        x = x.to(dtype)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x

class WanS2VAttentionBlock(WanAttentionBlock):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__(
            cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps
        )
        self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size, qk_norm, eps)

    def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens, dtype=torch.bfloat16, t=0):
        # e
        seg_idx = e[1].item()
        seg_idx = min(max(0, seg_idx), x.size(1))
        seg_idx = [0, seg_idx, x.size(1)]
        e = e[0]
        modulation = self.modulation.unsqueeze(2)
        e = (modulation + e).chunk(6, dim=1)
        e = [element.squeeze(1) for element in e]

        # norm
        norm_x = self.norm1(x).float()
        parts = []
        for i in range(2):
            parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] *
                         (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
        norm_x = torch.cat(parts, dim=1)
        # self-attention
        y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs)
        with amp.autocast(dtype=torch.float32):
            z = []
            for i in range(2):
                z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
            y = torch.cat(z, dim=1)
        x = x + y

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            norm2_x = self.norm2(x).float()
            parts = []
            for i in range(2):
                parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] *
                             (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
            norm2_x = torch.cat(parts, dim=1)
            y = self.ffn(norm2_x)
            with amp.autocast(dtype=torch.float32):
                z = []
                for i in range(2):
                    z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
                y = torch.cat(z, dim=1)
            x = x + y
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x

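# WanS2VAttentionBlock applies two independent sets of AdaLN modulation
# parameters to one packed sequence: tokens before `seg_idx` (the noisy video
# latents) use e[k][:, 0:1], tokens from `seg_idx` onward (reference/motion
# tokens, which may carry a zeroed timestep) use e[k][:, 1:2]. A sketch of the
# split with illustrative sizes:
#
#     x = torch.randn(1, 100, 2048)     # 80 video tokens + 20 ref tokens
#     seg_idx = [0, 80, 100]
#     scale = torch.randn(1, 2, 2048)   # one modulation row per segment
#     parts = [x[:, lo:hi] * (1 + scale[:, i:i + 1])
#              for i, (lo, hi) in enumerate(zip(seg_idx[:-1], seg_idx[1:]))]
#     x = torch.cat(parts, dim=1)       # back to [1, 100, 2048]
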
class Wan2_2Transformer3DModel_S2V(Wan2_2Transformer3DModel):
    # ignore_for_config = [
    #     'args', 'kwargs', 'patch_size', 'cross_attn_norm', 'qk_norm',
    #     'text_dim', 'window_size'
    # ]
    # _no_split_modules = ['WanS2VAttentionBlock']

    @register_to_config
    def __init__(
        self,
        cond_dim=0,
        audio_dim=5120,
        num_audio_token=4,
        enable_adain=False,
        adain_mode="attn_norm",
        audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
        zero_init=False,
        zero_timestep=False,
        enable_motioner=True,
        add_last_motion=True,
        enable_tsm=False,
        trainable_token_pos_emb=False,
        motion_token_num=1024,
        enable_framepack=False,  # Mutually exclusive with enable_motioner
        framepack_drop_mode="drop",
        model_type='s2v',
        patch_size=(1, 2, 2),
        text_len=512,
        in_dim=16,
        dim=2048,
        ffn_dim=8192,
        freq_dim=256,
        text_dim=4096,
        out_dim=16,
        num_heads=16,
        num_layers=32,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=True,
        eps=1e-6,
        in_channels=16,
        hidden_size=2048,
        *args,
        **kwargs
    ):
        super().__init__(
            model_type=model_type,
            patch_size=patch_size,
            text_len=text_len,
            in_dim=in_dim,
            dim=dim,
            ffn_dim=ffn_dim,
            freq_dim=freq_dim,
            text_dim=text_dim,
            out_dim=out_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            window_size=window_size,
            qk_norm=qk_norm,
            cross_attn_norm=cross_attn_norm,
            eps=eps,
            in_channels=in_channels,
            hidden_size=hidden_size
        )

        assert model_type == 's2v'
        self.enbale_adain = enable_adain
        self.adain_mode = adain_mode
        # Whether to assign a 0-value timestep to ref/motion tokens
        self.zero_timestep = zero_timestep
        self.enable_motioner = enable_motioner
        self.add_last_motion = add_last_motion
        self.enable_framepack = enable_framepack

        # Replace blocks
        self.blocks = nn.ModuleList([
            WanS2VAttentionBlock("cross_attn", dim, ffn_dim, num_heads, window_size, qk_norm,
                                 cross_attn_norm, eps)
            for _ in range(num_layers)
        ])

        # init audio injector
        all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
        if cond_dim > 0:
            self.cond_encoder = nn.Conv3d(
                cond_dim,
                self.dim,
                kernel_size=self.patch_size,
                stride=self.patch_size)
        self.trainable_cond_mask = nn.Embedding(3, self.dim)
        self.casual_audio_encoder = CausalAudioEncoder(
            dim=audio_dim,
            out_dim=self.dim,
            num_token=num_audio_token,
            need_global=enable_adain)
        self.audio_injector = AudioInjector_WAN(
            all_modules,
            all_modules_names,
            dim=self.dim,
            num_heads=self.num_heads,
            inject_layer=audio_inject_layers,
            root_net=self,
            enable_adain=enable_adain,
            adain_dim=self.dim,
            need_adain_ont=adain_mode != "attn_norm",
        )

        if zero_init:
            self.zero_init_weights()

        # init motioner
        if enable_motioner and enable_framepack:
            raise ValueError(
                "enable_motioner and enable_framepack are mutually exclusive, please set one of them to False"
            )
        if enable_motioner:
            motioner_dim = 2048
            self.motioner = MotionerTransformers(
                patch_size=(2, 4, 4),
                dim=motioner_dim,
                ffn_dim=motioner_dim,
                freq_dim=256,
                out_dim=16,
                num_heads=16,
                num_layers=13,
                window_size=(-1, -1),
                qk_norm=True,
                cross_attn_norm=False,
                eps=1e-6,
                motion_token_num=motion_token_num,
                enable_tsm=enable_tsm,
                motion_stride=4,
                expand_ratio=2,
                trainable_token_pos_emb=trainable_token_pos_emb,
            )
            self.zip_motion_out = torch.nn.Sequential(
                WanLayerNorm(motioner_dim),
                zero_module(nn.Linear(motioner_dim, self.dim)))

            self.trainable_token_pos_emb = trainable_token_pos_emb
            if trainable_token_pos_emb:
                d = self.dim // self.num_heads
                x = torch.zeros([1, motion_token_num, self.num_heads, d])
                x[..., ::2] = 1

                gride_sizes = [[
                    torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
                    torch.tensor([
                        1, self.motioner.motion_side_len,
                        self.motioner.motion_side_len
                    ]).unsqueeze(0).repeat(1, 1),
                    torch.tensor([
                        1, self.motioner.motion_side_len,
                        self.motioner.motion_side_len
                    ]).unsqueeze(0).repeat(1, 1),
                ]]
                token_freqs = s2v_rope_apply(x, gride_sizes, self.freqs)
                token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2)
                token_freqs = token_freqs * 0.01
                self.token_freqs = torch.nn.Parameter(token_freqs)

        if enable_framepack:
            self.frame_packer = FramePackMotioner(
                inner_dim=self.dim,
                num_heads=self.num_heads,
                zip_frame_buckets=[1, 2, 16],
                drop_mode=framepack_drop_mode)

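    # An instantiation sketch for the S2V variant (a hypothetical
    # configuration, not one shipped with this repo). Audio features from a
    # wav2vec-style encoder (audio_dim=5120) are compressed to
    # num_audio_token=4 tokens per frame and injected into the listed blocks;
    # motion context comes from either the motioner or frame packing, never both:
    #
    #     model = Wan2_2Transformer3DModel_S2V(
    #         cond_dim=16,             # enables the pose `cond_encoder`
    #         enable_adain=True,
    #         enable_motioner=False,
    #         enable_framepack=True,   # mutually exclusive with the motioner
    #     )
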
    def enable_multi_gpus_inference(self):
        self.sp_world_size = get_sequence_parallel_world_size()
        self.sp_world_rank = get_sequence_parallel_rank()
        self.all_gather = get_sp_group().all_gather
        for block in self.blocks:
            block.self_attn.forward = types.MethodType(
                usp_attn_s2v_forward, block.self_attn)

    def process_motion(self, motion_latents, drop_motion_frames=False):
        if drop_motion_frames or motion_latents[0].shape[1] == 0:
            return [], []
        self.lat_motion_frames = motion_latents[0].shape[1]
        mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents]
        batch_size = len(mot)

        mot_remb = []
        flattern_mot = []
        for bs in range(batch_size):
            height, width = mot[bs].shape[3], mot[bs].shape[4]
            flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous()
            motion_grid_sizes = [[
                torch.tensor([-self.lat_motion_frames, 0,
                              0]).unsqueeze(0).repeat(1, 1),
                torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1),
                torch.tensor([self.lat_motion_frames, height,
                              width]).unsqueeze(0).repeat(1, 1)
            ]]
            motion_rope_emb = rope_precompute(
                flat_mot.detach().view(1, flat_mot.shape[1], self.num_heads,
                                       self.dim // self.num_heads),
                motion_grid_sizes,
                self.freqs,
                start=None)
            mot_remb.append(motion_rope_emb)
            flattern_mot.append(flat_mot)
        return flattern_mot, mot_remb

    def process_motion_frame_pack(self,
                                  motion_latents,
                                  drop_motion_frames=False,
                                  add_last_motion=2):
        flattern_mot, mot_remb = self.frame_packer(motion_latents,
                                                   add_last_motion)
        if drop_motion_frames:
            return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb]
        else:
            return flattern_mot, mot_remb

    def process_motion_transformer_motioner(self,
                                            motion_latents,
                                            drop_motion_frames=False,
                                            add_last_motion=True):
        batch_size = len(motion_latents)
        height = motion_latents[0].shape[2] // self.patch_size[1]
        width = motion_latents[0].shape[3] // self.patch_size[2]

        freqs = self.freqs
        device = self.patch_embedding.weight.device
        if freqs.device != device:
            freqs = freqs.to(device)
        if self.trainable_token_pos_emb:
            with amp.autocast(dtype=torch.float64):
                token_freqs = self.token_freqs.to(torch.float64)
                token_freqs = token_freqs / token_freqs.norm(
                    dim=-1, keepdim=True)
                freqs = [freqs, torch.view_as_complex(token_freqs)]

        if not drop_motion_frames and add_last_motion:
            last_motion_latent = [u[:, -1:] for u in motion_latents]
            last_mot = [
                self.patch_embedding(m.unsqueeze(0)) for m in last_motion_latent
            ]
            last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot]
            last_mot = torch.cat(last_mot)
            gride_sizes = [[
                torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
                torch.tensor([0, height,
                              width]).unsqueeze(0).repeat(batch_size, 1),
                torch.tensor([1, height,
                              width]).unsqueeze(0).repeat(batch_size, 1)
            ]]
        else:
            last_mot = torch.zeros([batch_size, 0, self.dim],
                                   device=motion_latents[0].device,
                                   dtype=motion_latents[0].dtype)
            gride_sizes = []

        zip_motion = self.motioner(motion_latents)
        zip_motion = self.zip_motion_out(zip_motion)
        if drop_motion_frames:
            zip_motion = zip_motion * 0.0
        zip_motion_grid_sizes = [[
            torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor([
                0, self.motioner.motion_side_len, self.motioner.motion_side_len
            ]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor(
                [1 if not self.trainable_token_pos_emb else -1, height,
                 width]).unsqueeze(0).repeat(batch_size, 1),
        ]]

        mot = torch.cat([last_mot, zip_motion], dim=1)
        gride_sizes = gride_sizes + zip_motion_grid_sizes

        motion_rope_emb = rope_precompute(
            mot.detach().view(batch_size, mot.shape[1], self.num_heads,
                              self.dim // self.num_heads),
            gride_sizes,
            freqs,
            start=None)
        return [m.unsqueeze(0) for m in mot], [r.unsqueeze(0) for r in motion_rope_emb]

    def inject_motion(self,
                      x,
                      seq_lens,
                      rope_embs,
                      mask_input,
                      motion_latents,
                      drop_motion_frames=False,
                      add_last_motion=True):
        # Inject the motion frames token to the hidden states
        if self.enable_motioner:
            mot, mot_remb = self.process_motion_transformer_motioner(
                motion_latents,
                drop_motion_frames=drop_motion_frames,
                add_last_motion=add_last_motion)
        elif self.enable_framepack:
            mot, mot_remb = self.process_motion_frame_pack(
                motion_latents,
                drop_motion_frames=drop_motion_frames,
                add_last_motion=add_last_motion)
        else:
            mot, mot_remb = self.process_motion(
                motion_latents, drop_motion_frames=drop_motion_frames)

        if len(mot) > 0:
            x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)]
            seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot],
                                               dtype=torch.long)
            rope_embs = [
                torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)
            ]
            mask_input = [
                torch.cat([
                    m, 2 * torch.ones([1, u.shape[1] - m.shape[1]],
                                      device=m.device,
                                      dtype=m.dtype)
                ],
                          dim=1) for m, u in zip(mask_input, x)
            ]
        return x, seq_lens, rope_embs, mask_input

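    # `trainable_cond_mask` is an nn.Embedding(3, dim); `inject_motion` extends
    # each mask row with the value 2 so the three token types get distinct
    # learned additive biases: 0 = noisy video tokens, 1 = reference tokens,
    # 2 = motion tokens. For one sample with illustrative lengths:
    #
    #     video (80) + ref (4) + motion (16)  ->  mask row [0]*80 + [1]*4 + [2]*16
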
    def after_transformer_block(self, block_idx, hidden_states):
        if block_idx in self.audio_injector.injected_block_id.keys():
            audio_attn_id = self.audio_injector.injected_block_id[block_idx]
            audio_emb = self.merged_audio_emb  # b f n c
            num_frames = audio_emb.shape[1]

            if self.sp_world_size > 1:
                hidden_states = self.all_gather(hidden_states, dim=1)

            input_hidden_states = hidden_states[:, :self.original_seq_len].clone()
            input_hidden_states = rearrange(
                input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)

            if self.enbale_adain and self.adain_mode == "attn_norm":
                audio_emb_global = self.audio_emb_global
                audio_emb_global = rearrange(audio_emb_global,
                                             "b t n c -> (b t) n c")
                adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](
                    input_hidden_states, temb=audio_emb_global[:, 0]
                )
                attn_hidden_states = adain_hidden_states
            else:
                attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](
                    input_hidden_states
                )
            audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
            attn_audio_emb = audio_emb
            context_lens = torch.ones(
                attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device
            ) * attn_audio_emb.shape[1]
            residual_out = self.audio_injector.injector[audio_attn_id](
                x=attn_hidden_states,
                context=attn_audio_emb,
                context_lens=context_lens)
            residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
            hidden_states[:, :self.original_seq_len] = hidden_states[:, :self.original_seq_len] + residual_out

            if self.sp_world_size > 1:
                hidden_states = torch.chunk(
                    hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]

        return hidden_states

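    # The rearranges in `after_transformer_block` treat the video part of the
    # sequence as `t` frames of `n` spatial tokens, so each frame cross-attends
    # only to its own audio tokens. Shape sketch (illustrative sizes): hidden
    # states [1, 21 * 1560, C] -> "(b t) n c" = [21, 1560, C]; audio
    # [1, 21, 4, C] -> [21, 4, C]; the attention residual folds back to
    # [1, 21 * 1560, C] and is added to the video tokens only; the ref/motion
    # tail beyond `original_seq_len` is left untouched.
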
    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        ref_latents,
        motion_latents,
        cond_states,
        audio_input=None,
        motion_frames=[17, 5],
        add_last_motion=2,
        drop_motion_frames=False,
        cond_flag=True,
        *extra_args,
        **extra_kwargs
    ):
        """
        x: A list of videos each with shape [C, T, H, W].
        t: [B].
        context: A list of text embeddings each with shape [L, C].
        seq_len: A list of video token lengths; not needed for this model.
        ref_latents: A list of reference images, one per video, each with shape [C, 1, H, W].
        motion_latents: A list of motion frames for each video with shape [C, T_m, H, W].
        cond_states: A list of condition frames (i.e. pose) each with shape [C, T, H, W].
        audio_input: The input audio embedding [B, num_wav2vec_layer, C_a, T_a].
        motion_frames: The number of motion frames and motion latent frames encoded by the VAE, e.g. [17, 5].
        add_last_motion: For the motioner, if add_last_motion > 0, the most recent frame (i.e., the last frame) will be added.
            For frame packing, the behavior depends on the value of add_last_motion:
                add_last_motion = 0: only the farthest part of the latents (i.e., clean_latents_4x) is included.
                add_last_motion = 1: both clean_latents_2x and clean_latents_4x are included.
                add_last_motion = 2: all motion-related latents are used.
        drop_motion_frames: Bool, whether to drop the motion frames info.
        """
        device = self.patch_embedding.weight.device
        dtype = x.dtype
        add_last_motion = self.add_last_motion * add_last_motion

        # Embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]

        # Audio process
        audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
        audio_emb_res = self.casual_audio_encoder(audio_input)
        if self.enbale_adain:
            audio_emb_global, audio_emb = audio_emb_res
            self.audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
        else:
            audio_emb = audio_emb_res
        self.merged_audio_emb = audio_emb[:, motion_frames[1]:, :]

        # Cond states
        cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states]
        x = [x_ + pose for x_, pose in zip(x, cond)]

        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)

        original_grid_sizes = deepcopy(grid_sizes)
        grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]

        # Ref latents
        ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents]
        batch_size = len(ref)
        height, width = ref[0].shape[3], ref[0].shape[4]
        ref = [r.flatten(2).transpose(1, 2) for r in ref]  # r: 1 c f h w
        x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)]

        self.original_seq_len = seq_lens[0]
        seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref], dtype=torch.long)
        ref_grid_sizes = [
            [
                torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1),  # the start index
                torch.tensor([31, height, width]).unsqueeze(0).repeat(batch_size, 1),  # the end index
                torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1),
            ]  # the range
        ]
        grid_sizes = grid_sizes + ref_grid_sizes

        # Compute the rope embeddings for the input
        x = torch.cat(x)
        b, s, n, d = x.size(0), x.size(1), self.num_heads, self.dim // self.num_heads
        self.pre_compute_freqs = rope_precompute(
            x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None)
        x = [u.unsqueeze(0) for u in x]
        self.pre_compute_freqs = [u.unsqueeze(0) for u in self.pre_compute_freqs]

        # Inject Motion latents.
        # Initialize masks to indicate noisy latent, ref latent, and motion latent.
        # However, at this point, only the first two (noisy and ref latents) are marked;
        # the marking of motion latent will be implemented inside `inject_motion`.
        mask_input = [
            torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device)
            for u in x
        ]
        for i in range(len(mask_input)):
            mask_input[i][:, self.original_seq_len:] = 1

        self.lat_motion_frames = motion_latents[0].shape[1]
        x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion(
            x,
            seq_lens,
            self.pre_compute_freqs,
            mask_input,
            motion_latents,
            drop_motion_frames=drop_motion_frames,
            add_last_motion=add_last_motion)
        x = torch.cat(x, dim=0)
        self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0)
        mask_input = torch.cat(mask_input, dim=0)

        # Apply trainable_cond_mask
        x = x + self.trainable_cond_mask(mask_input).to(x.dtype)

        seq_len = seq_lens.max()
        if self.sp_world_size > 1:
            seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u.unsqueeze(0), u.new_zeros(1, seq_len - u.size(0), u.size(1))],
                      dim=1) for u in x
        ])

        # Time embeddings
        if self.zero_timestep:
            t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)])
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32
        if self.zero_timestep:
            e = e[:-1]
            zero_e0 = e0[-1:]
            e0 = e0[:-1]
            token_len = x.shape[1]

            e0 = torch.cat(
                [
                    e0.unsqueeze(2),
                    zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)
                ],
                dim=2
            )
            e0 = [e0, self.original_seq_len]
        else:
            e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1)
            e0 = [e0, 0]

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if self.sp_world_size > 1:
            # Sharded tensors for long context attn
            x = torch.chunk(x, self.sp_world_size, dim=1)
            sq_size = [u.shape[1] for u in x]
            sq_start_size = sum(sq_size[:self.sp_world_rank])
            x = x[self.sp_world_rank]
            # Confirm the application range of the time embedding in e0[0] for each sequence:
            # - For tokens before seg_idx: apply e0[0][:, :, 0]
            # - For tokens after seg_idx: apply e0[0][:, :, 1]
            sp_size = x.shape[1]
            seg_idx = e0[1] - sq_start_size
            e0[1] = seg_idx

            self.pre_compute_freqs = torch.chunk(self.pre_compute_freqs, self.sp_world_size, dim=1)
            self.pre_compute_freqs = self.pre_compute_freqs[self.sp_world_rank]

        # TeaCache
        if self.teacache is not None:
            if cond_flag:
                if t.dim() != 1:
                    modulated_inp = e0[0][:, -1, :]
                else:
                    modulated_inp = e0[0]
                skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
                if skip_flag:
                    self.should_calc = True
                    self.teacache.accumulated_rel_l1_distance = 0
                else:
                    if cond_flag:
                        rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
                        self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
                    if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
                        self.should_calc = False
                    else:
                        self.should_calc = True
                        self.teacache.accumulated_rel_l1_distance = 0
                self.teacache.previous_modulated_input = modulated_inp
                self.teacache.should_calc = self.should_calc
            else:
                self.should_calc = self.teacache.should_calc

        # TeaCache
        if self.teacache is not None:
            if not self.should_calc:
                previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
                x = x + previous_residual.to(x.device)[-x.size()[0]:,]
            else:
                ori_x = x.clone().cpu() if self.teacache.offload else x.clone()

                for idx, block in enumerate(self.blocks):
                    if torch.is_grad_enabled() and self.gradient_checkpointing:

                        def create_custom_forward(module):
                            def custom_forward(*inputs):
                                return module(*inputs)

                            return custom_forward
                        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x,
                            e0,
                            seq_lens,
                            grid_sizes,
                            self.pre_compute_freqs,
                            context,
                            context_lens,
                            dtype,
                            t,
                            **ckpt_kwargs,
                        )
                        x = self.after_transformer_block(idx, x)
                    else:
                        # arguments
                        kwargs = dict(
                            e=e0,
                            seq_lens=seq_lens,
                            grid_sizes=grid_sizes,
                            freqs=self.pre_compute_freqs,
                            context=context,
                            context_lens=context_lens,
                            dtype=dtype,
                            t=t
                        )
                        x = block(x, **kwargs)
                        x = self.after_transformer_block(idx, x)

                if cond_flag:
                    self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
                else:
                    self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
        else:
            for idx, block in enumerate(self.blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        e0,
                        seq_lens,
                        grid_sizes,
                        self.pre_compute_freqs,
                        context,
                        context_lens,
                        dtype,
                        t,
                        **ckpt_kwargs,
                    )
                    x = self.after_transformer_block(idx, x)
                else:
                    # arguments
                    kwargs = dict(
                        e=e0,
                        seq_lens=seq_lens,
                        grid_sizes=grid_sizes,
                        freqs=self.pre_compute_freqs,
                        context=context,
                        context_lens=context_lens,
                        dtype=dtype,
                        t=t
                    )
                    x = block(x, **kwargs)
                    x = self.after_transformer_block(idx, x)

        # Context Parallel
        if self.sp_world_size > 1:
            x = self.all_gather(x.contiguous(), dim=1)

        # Unpatchify
        x = x[:, :self.original_seq_len]
        # Head
        x = self.head(x, e)
        x = self.unpatchify(x, original_grid_sizes)
        x = torch.stack(x)
        if self.teacache is not None and cond_flag:
            self.teacache.cnt += 1
            if self.teacache.cnt == self.teacache.num_steps:
                self.teacache.reset()
        return x

    def unpatchify(self, x, grid_sizes):
        """
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

|
| 876 |
+
def zero_init_weights(self):
|
| 877 |
+
with torch.no_grad():
|
| 878 |
+
self.trainable_cond_mask = zero_module(self.trainable_cond_mask)
|
| 879 |
+
if hasattr(self, "cond_encoder"):
|
| 880 |
+
self.cond_encoder = zero_module(self.cond_encoder)
|
| 881 |
+
|
| 882 |
+
for i in range(self.audio_injector.injector.__len__()):
|
| 883 |
+
self.audio_injector.injector[i].o = zero_module(
|
| 884 |
+
self.audio_injector.injector[i].o)
|
| 885 |
+
if self.enbale_adain:
|
| 886 |
+
self.audio_injector.injector_adain_layers[i].linear = \
|
| 887 |
+
zero_module(self.audio_injector.injector_adain_layers[i].linear)
|
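For reference, the `unpatchify` einsum above can be sanity-checked in isolation. Below is a minimal standalone sketch; the sizes are illustrative assumptions, not values from this model's config. It rebuilds a dense [C, F', H', W'] tensor from a fake patch sequence using the same three steps as the method:

import math

import torch

# Illustrative sizes (assumptions): channel count c, patch_size (pt, ph, pw),
# and grid sizes v = (F_patches, H_patches, W_patches).
c, patch_size, v = 8, (1, 2, 2), (3, 4, 5)

# Fake patch features of shape [L, prod(patch_size) * c] with L = prod(v),
# i.e. the layout unpatchify receives per sample.
u = torch.randn(math.prod(v), math.prod(patch_size) * c)

# Same three steps as unpatchify above.
u = u[:math.prod(v)].view(*v, *patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, patch_size)])

print(u.shape)  # torch.Size([8, 3, 8, 10]) == [c, Fp*pt, Hp*ph, Wp*pw]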
videox_fun/models/wan_transformer3d_vace.py
ADDED
@@ -0,0 +1,392 @@
# Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import os
import math
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from diffusers.utils import is_torch_version

from .wan_transformer3d import (WanAttentionBlock, WanTransformer3DModel,
                                sinusoidal_embedding_1d)


VIDEOX_OFFLOAD_VACE_LATENTS = os.environ.get("VIDEOX_OFFLOAD_VACE_LATENTS", False)


class VaceWanAttentionBlock(WanAttentionBlock):
    def __init__(
        self,
        cross_attn_type,
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        block_id=0
    ):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = nn.Linear(self.dim, self.dim)
            nn.init.zeros_(self.before_proj.weight)
            nn.init.zeros_(self.before_proj.bias)
        self.after_proj = nn.Linear(self.dim, self.dim)
        nn.init.zeros_(self.after_proj.weight)
        nn.init.zeros_(self.after_proj.bias)

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            c = self.before_proj(c) + x
            all_c = []
        else:
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)

        if VIDEOX_OFFLOAD_VACE_LATENTS:
            c = c.to(x.device)

        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)

        if VIDEOX_OFFLOAD_VACE_LATENTS:
            c_skip = c_skip.to("cpu")
            c = c.to("cpu")

        all_c += [c_skip, c]
        c = torch.stack(all_c)
        return c


class BaseWanAttentionBlock(WanAttentionBlock):
    def __init__(
        self,
        cross_attn_type,
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        block_id=None
    ):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
        self.block_id = block_id

    def forward(self, x, hints, context_scale=1.0, **kwargs):
        x = super().forward(x, **kwargs)
        if self.block_id is not None:
            if VIDEOX_OFFLOAD_VACE_LATENTS:
                x = x + hints[self.block_id].to(x.device) * context_scale
            else:
                x = x + hints[self.block_id] * context_scale
        return x


class VaceWanTransformer3DModel(WanTransformer3DModel):
    @register_to_config
    def __init__(self,
                 vace_layers=None,
                 vace_in_dim=None,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        model_type = "t2v"  # TODO: Hard code for both preview and official versions.
        super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
                         num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)

        self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
        self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim

        assert 0 in self.vace_layers
        self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}

        # blocks
        self.blocks = nn.ModuleList([
            BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
                                  self.cross_attn_norm, self.eps,
                                  block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
            for i in range(self.num_layers)
        ])

        # vace blocks
        self.vace_blocks = nn.ModuleList([
            VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
                                  self.cross_attn_norm, self.eps, block_id=i)
            for i in self.vace_layers
        ])

        # vace patch embeddings
        self.vace_patch_embedding = nn.Conv3d(
            self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
        )

    def forward_vace(
        self,
        x,
        vace_context,
        seq_len,
        kwargs
    ):
        # embeddings
        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
        c = [u.flatten(2).transpose(1, 2) for u in c]
        c = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in c
        ])
        # Context Parallel
        if self.sp_world_size > 1:
            c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]

        # arguments
        new_kwargs = dict(x=x)
        new_kwargs.update(kwargs)

        for block in self.vace_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                def create_custom_forward(module, **static_kwargs):
                    def custom_forward(*inputs):
                        return module(*inputs, **static_kwargs)
                    return custom_forward
                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                c = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block, **new_kwargs),
                    c,
                    **ckpt_kwargs,
                )
            else:
                c = block(c, **new_kwargs)
        hints = torch.unbind(c)[:-1]
        return hints

    def forward(
        self,
        x,
        t,
        vace_context,
        context,
        seq_len,
        vace_context_scale=1.0,
        clip_fea=None,
        y=None,
        cond_flag=True
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        # if self.model_type == 'i2v':
        #     assert clip_fea is not None and y is not None
        # params
        dtype = x.dtype
        device = self.patch_embedding.weight.device
        if self.freqs.device != device and torch.device(type="meta") != device:
            self.freqs = self.freqs.to(device)

        # if y is not None:
        #     x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        if self.sp_world_size > 1:
            seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        # Context Parallel
        if self.sp_world_size > 1:
            x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens,
            dtype=dtype,
            t=t)
        hints = self.forward_vace(x, vace_context, seq_len, kwargs)

        kwargs['hints'] = hints
        kwargs['context_scale'] = vace_context_scale

        # TeaCache
        if self.teacache is not None:
            if cond_flag:
                if t.dim() != 1:
                    modulated_inp = e0[:, -1, :]
                else:
                    modulated_inp = e0
                skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
                if skip_flag:
                    self.should_calc = True
                    self.teacache.accumulated_rel_l1_distance = 0
                else:
                    if cond_flag:
                        rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
                        self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
                    if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
                        self.should_calc = False
                    else:
                        self.should_calc = True
                        self.teacache.accumulated_rel_l1_distance = 0
                self.teacache.previous_modulated_input = modulated_inp
                self.teacache.should_calc = self.should_calc
            else:
                self.should_calc = self.teacache.should_calc

        # TeaCache
        if self.teacache is not None:
            if not self.should_calc:
                previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
                x = x + previous_residual.to(x.device)[-x.size()[0]:,]
            else:
                ori_x = x.clone().cpu() if self.teacache.offload else x.clone()

                for block in self.blocks:
                    if torch.is_grad_enabled() and self.gradient_checkpointing:
                        def create_custom_forward(module, **static_kwargs):
                            def custom_forward(*inputs):
                                return module(*inputs, **static_kwargs)
                            return custom_forward
                        extra_kwargs = {
                            'e': e0,
                            'seq_lens': seq_lens,
                            'grid_sizes': grid_sizes,
                            'freqs': self.freqs,
                            'context': context,
                            'context_lens': context_lens,
                            'dtype': dtype,
                            't': t,
                        }

                        ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block, **extra_kwargs),
                            x,
                            hints,
                            vace_context_scale,
                            **ckpt_kwargs,
                        )
                    else:
                        x = block(x, **kwargs)

                if cond_flag:
                    self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
                else:
                    self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
        else:
            for block in self.blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    def create_custom_forward(module, **static_kwargs):
                        def custom_forward(*inputs):
                            return module(*inputs, **static_kwargs)
                        return custom_forward
                    extra_kwargs = {
                        'e': e0,
                        'seq_lens': seq_lens,
                        'grid_sizes': grid_sizes,
                        'freqs': self.freqs,
                        'context': context,
                        'context_lens': context_lens,
                        'dtype': dtype,
                        't': t,
                    }

                    ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block, **extra_kwargs),
                        x,
                        hints,
                        vace_context_scale,
                        **ckpt_kwargs,
                    )
                else:
                    x = block(x, **kwargs)

        # head
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
        else:
            x = self.head(x, e)

        if self.sp_world_size > 1:
            x = self.all_gather(x, dim=1)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        x = torch.stack(x)
        if self.teacache is not None and cond_flag:
            self.teacache.cnt += 1
            if self.teacache.cnt == self.teacache.num_steps:
                self.teacache.reset()
        return x
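A key design choice above is that the whole VACE branch is an exact no-op at initialization: `before_proj` and `after_proj` start at zero, so every hint `c_skip` is zero and `x + hints[self.block_id] * context_scale` returns `x` unchanged until training moves the projections away from zero. Below is a minimal self-contained sketch of that pattern with toy sizes and a plain Linear standing in for the full WanAttentionBlock (all names and dimensions here are illustrative assumptions):

import torch
import torch.nn as nn

dim = 16  # toy width, assumption for illustration only

# Zero-initialized in/out projections around an arbitrary inner transform,
# mirroring before_proj / after_proj around the attention block body.
before_proj = nn.Linear(dim, dim)
after_proj = nn.Linear(dim, dim)
for proj in (before_proj, after_proj):
    nn.init.zeros_(proj.weight)
    nn.init.zeros_(proj.bias)
inner = nn.Linear(dim, dim)  # stand-in for super().forward(...)

x = torch.randn(2, 5, dim)       # main-stream tokens
vace_c = torch.randn(2, 5, dim)  # control-branch tokens

c = before_proj(vace_c) + x      # zero weights => c == x at init
c_skip = after_proj(inner(c))    # zero weights => c_skip == 0 at init
out = x + c_skip * 1.0           # hint injection is the identity at init

print(torch.allclose(out, x))    # True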
videox_fun/models/wan_vae.py
ADDED
@@ -0,0 +1,706 @@
# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.autoencoders.vae import (DecoderOutput,
                                               DiagonalGaussianDistribution)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange


CACHE_T = 2


class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)


class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias


class Upsample(nn.Upsample):

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
                        # cache last frame of last two chunk
                        cache_x = torch.cat([
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                            cache_x
                        ], dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
                        cache_x = torch.cat([
                            torch.zeros_like(cache_x).to(cache_x.device),
                            cache_x
                        ], dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
                    #     # cache last frame of last two chunk
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
        conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)


class ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                        cache_x
                    ], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h


class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)
        # compute query, key, value
        q, k, v = (self.to_qkv(x)
                   .reshape(b * t, 1, c * 3, -1)
                   .permute(0, 1, 3, 2)
                   .contiguous()
                   .chunk(3, dim=-1))

        # apply attention
        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
        )
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
        return x + identity


class Encoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                    cache_x
                ], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                        cache_x
                    ], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


class Decoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                    cache_x
                ], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
                        cache_x
                    ], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count


class AutoencoderKLWan_(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        scale = [item.to(x.device, x.dtype) for item in scale]
        ## Split the input of encode along the time axis into chunks of 1, 4, 4, 4, ...
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        x = torch.cat([mu, log_var], dim=1)
        self.clear_cache()
        return x

    def decode(self, z, scale):
        self.clear_cache()
        # z: [b,c,t,h,w]
        scale = [item.to(z.device, z.dtype) for item in scale]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        #cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num


def _video_vae(z_dim=None, **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    model = AutoencoderKLWan_(**cfg)

    return model


class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):

    @register_to_config
    def __init__(
        self,
        latent_channels=16,
        temporal_compression_ratio=4,
        spatial_compression_ratio=8
    ):
        super().__init__()
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=torch.float32)
        self.std = torch.tensor(std, dtype=torch.float32)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = _video_vae(
            z_dim=latent_channels,
        )

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        x = [
            self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
            for u in x
        ]
        x = torch.stack(x)
        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, zs):
        dec = [
            self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
            for u in zs
        ]
        dec = torch.stack(dec)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
        def filter_kwargs(cls, kwargs):
            import inspect
            sig = inspect.signature(cls.__init__)
            valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
            return filtered_kwargs

        model = cls(**filter_kwargs(cls, additional_kwargs))
        if pretrained_model_path.endswith(".safetensors"):
            from safetensors.torch import load_file, safe_open
            state_dict = load_file(pretrained_model_path)
        else:
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
        tmp_state_dict = {}
        for key in state_dict:
            tmp_state_dict["model." + key] = state_dict[key]
        state_dict = tmp_state_dict
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m, u)
        return model
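Note how `encode` above walks the video causally in time: the first frame alone, then groups of four frames, with `feat_cache` carrying the trailing CACHE_T frames of every CausalConv3d so each chunk sees the same left context a single full pass would. A short sketch of just the chunk-boundary arithmetic (the frame count here is an illustrative assumption; Wan-style models use 4k+1 frame counts):

# Chunk boundaries used by encode: frame 0 alone, then groups of 4,
# matching the 4x temporal compression (1 latent frame, then 1 per 4-frame chunk).
t = 17  # illustrative frame count (assumption)
iter_ = 1 + (t - 1) // 4

chunks = [(0, 1)] + [(1 + 4 * (i - 1), 1 + 4 * i) for i in range(1, iter_)]
print(chunks)  # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)] -- covers all t frames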
videox_fun/models/wan_vae3_8.py
ADDED
@@ -0,0 +1,1080 @@
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.cuda.amp as amp
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 10 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 11 |
+
from diffusers.models.autoencoders.vae import (DecoderOutput,
|
| 12 |
+
DiagonalGaussianDistribution)
|
| 13 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 14 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 15 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
CACHE_T = 2
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class CausalConv3d(nn.Conv3d):
|
| 23 |
+
"""
|
| 24 |
+
Causal 3d convolusion.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, *args, **kwargs):
|
| 28 |
+
super().__init__(*args, **kwargs)
|
| 29 |
+
self._padding = (
|
| 30 |
+
self.padding[2],
|
| 31 |
+
self.padding[2],
|
| 32 |
+
self.padding[1],
|
| 33 |
+
self.padding[1],
|
| 34 |
+
2 * self.padding[0],
|
| 35 |
+
0,
|
| 36 |
+
)
|
| 37 |
+
self.padding = (0, 0, 0)
|
| 38 |
+
|
| 39 |
+
def forward(self, x, cache_x=None):
|
| 40 |
+
padding = list(self._padding)
|
| 41 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 42 |
+
cache_x = cache_x.to(x.device)
|
| 43 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 44 |
+
padding[4] -= cache_x.shape[2]
|
| 45 |
+
x = F.pad(x, padding)
|
| 46 |
+
|
| 47 |
+
return super().forward(x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
                self.scale * self.gamma + self.bias)


class Upsample(nn.Upsample):

    def forward(self, x):
        """
        Fix bfloat16 support for nearest-neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in (
            "none",
            "upsample2d",
            "upsample3d",
            "downsample2d",
            "downsample3d",
        )
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim, 3, padding=1),
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim, 3, padding=1),
            )
            # Doubles the channels so pairs of channels can be interleaved
            # into twice as many frames in forward().
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample2d":
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: mark the slot so the next chunk knows the
                    # temporal conv ran without a real cache.
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
                            feat_cache[idx] != "Rep"):
                        # cache the last frame of the previous two chunks
                        cache_x = torch.cat(
                            [
                                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                    cache_x.device),
                                cache_x,
                            ],
                            dim=2,
                        )
                    if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
                            feat_cache[idx] == "Rep"):
                        cache_x = torch.cat(
                            [
                                torch.zeros_like(cache_x).to(cache_x.device),
                                cache_x
                            ],
                            dim=2,
                        )
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
                    # Interleave the doubled channels into doubled frames.
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        conv_weight = conv.weight.detach().clone()
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1, c2)
        conv_weight.data[:, :, 1, 0, 0] = init_matrix  # * 0.5
        conv.weight = nn.Parameter(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        conv_weight = conv.weight.data.detach().clone()
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight = nn.Parameter(conv_weight)
        nn.init.zeros_(conv.bias.data)


class ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False),
            nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1),
        )
        self.shortcut = (
            CausalConv3d(in_dim, out_dim, 1)
            if in_dim != out_dim else nn.Identity())

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache the last frame of the previous two chunks
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h


class AttentionBlock(nn.Module):
    """
    Single-head self-attention, applied independently to each frame.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.norm(x)
        # compute query, key, value
        q, k, v = (
            self.to_qkv(x).reshape(b * t, 1, c * 3,
                                   -1).permute(0, 1, 3,
                                               2).contiguous().chunk(3, dim=-1))

        # apply attention
        x = F.scaled_dot_product_attention(q, k, v)
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        x = self.proj(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x + identity


def patchify(x, patch_size):
    if patch_size == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c f (h q) (w r) -> b (c r q) f h w",
            q=patch_size,
            r=patch_size,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size):
    if patch_size == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c r q) f h w -> b c f (h q) (w r)",
            q=patch_size,
            r=patch_size,
        )
    return x
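
# A minimal round-trip sketch (editor's addition, not part of the original
# file): patchify folds each non-overlapping patch_size x patch_size spatial
# patch into the channel axis, and unpatchify inverts it exactly, which is why
# the encoder below can run on a 12-channel, half-resolution view of a
# 3-channel video.
def _patchify_roundtrip_sketch():
    video = torch.randn(1, 3, 5, 64, 64)            # [B, C, T, H, W]
    patched = patchify(video, patch_size=2)         # -> [1, 12, 5, 32, 32]
    restored = unpatchify(patched, patch_size=2)    # -> [1, 3, 5, 64, 64]
    assert torch.equal(video, restored)
    return patched.shape
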
class AvgDown3D(nn.Module):
    """
    Parameter-free downsampling: space/time-to-channel folding followed by
    channel-group averaging.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Left-pad in time so T becomes divisible by factor_t.
        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
        pad = (0, 0, 0, 0, pad_t, 0)
        x = F.pad(x, pad)
        B, C, T, H, W = x.shape
        x = x.view(
            B,
            C,
            T // self.factor_t,
            self.factor_t,
            H // self.factor_s,
            self.factor_s,
            W // self.factor_s,
            self.factor_s,
        )
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(
            B,
            C * self.factor,
            T // self.factor_t,
            H // self.factor_s,
            W // self.factor_s,
        )
        x = x.view(
            B,
            self.out_channels,
            self.group_size,
            T // self.factor_t,
            H // self.factor_s,
            W // self.factor_s,
        )
        x = x.mean(dim=2)
        return x


class DupUp3D(nn.Module):
    """
    Inverse of AvgDown3D: channel duplication followed by channel-to-space/time
    unfolding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        assert out_channels * self.factor % in_channels == 0
        self.repeats = out_channels * self.factor // in_channels

    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        x = x.repeat_interleave(self.repeats, dim=1)
        x = x.view(
            x.size(0),
            self.out_channels,
            self.factor_t,
            self.factor_s,
            self.factor_s,
            x.size(2),
            x.size(3),
            x.size(4),
        )
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(
            x.size(0),
            self.out_channels,
            x.size(2) * self.factor_t,
            x.size(4) * self.factor_s,
            x.size(6) * self.factor_s,
        )
        if first_chunk:
            # Drop the frames that correspond to the temporal padding added by
            # AvgDown3D on the first chunk.
            x = x[:, :, self.factor_t - 1:, :, :]
        return x


class Down_ResidualBlock(nn.Module):

    def __init__(self,
                 in_dim,
                 out_dim,
                 dropout,
                 mult,
                 temperal_downsample=False,
                 down_flag=False):
        super().__init__()

        # Shortcut path with downsample
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path with residual blocks and downsample
        downsamples = []
        for _ in range(mult):
            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final downsample block
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            downsamples.append(Resample(out_dim, mode=mode))

        self.downsamples = nn.Sequential(*downsamples)

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        x_copy = x.clone()
        for module in self.downsamples:
            x = module(x, feat_cache, feat_idx)

        return x + self.avg_shortcut(x_copy)


class Up_ResidualBlock(nn.Module):

    def __init__(self,
                 in_dim,
                 out_dim,
                 dropout,
                 mult,
                 temperal_upsample=False,
                 up_flag=False):
        super().__init__()
        # Shortcut path with upsample
        if up_flag:
            self.avg_shortcut = DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2 if up_flag else 1,
            )
        else:
            self.avg_shortcut = None

        # Main path with residual blocks and upsample
        upsamples = []
        for _ in range(mult):
            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final upsample block
        if up_flag:
            mode = "upsample3d" if temperal_upsample else "upsample2d"
            upsamples.append(Resample(out_dim, mode=mode))

        self.upsamples = nn.Sequential(*upsamples)

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        x_main = x.clone()
        for module in self.upsamples:
            x_main = module(x_main, feat_cache, feat_idx)
        if self.avg_shortcut is not None:
            x_shortcut = self.avg_shortcut(x, first_chunk)
            return x_main + x_shortcut
        else:
            return x_main
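
# A minimal shape sketch (editor's addition, not part of the original file):
# AvgDown3D trades time/space resolution for channels and averages channel
# groups; a DupUp3D with mirrored factors restores the original resolution.
# This is how the parameter-free shortcut paths above change resolution.
def _avgdown_dupup_sketch():
    down = AvgDown3D(16, 32, factor_t=2, factor_s=2)
    up = DupUp3D(32, 16, factor_t=2, factor_s=2)
    x = torch.randn(1, 16, 8, 32, 32)          # [B, C, T, H, W]
    y = down(x)                                # -> [1, 32, 4, 16, 16]
    z = up(y)                                  # -> [1, 16, 8, 32, 32]
    return y.shape, z.shape
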
class Encoder3d(nn.Module):

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]

        # init block (12 input channels: 3 RGB channels x 2x2 spatial patches)
        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_down_flag = (
                temperal_downsample[i]
                if i < len(temperal_downsample) else False)
            downsamples.append(
                Down_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks,
                    temperal_downsample=t_down_flag,
                    down_flag=i != len(dim_mult) - 1,
                ))
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout),
            AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout),
        )

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of the previous two chunks
                cache_x = torch.cat(
                    [
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device),
                        cache_x,
                    ],
                    dim=2,
                )
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)

        return x


class Decoder3d(nn.Module):

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout),
        )

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up_flag = temperal_upsample[i] if i < len(
                temperal_upsample) else False
            upsamples.append(
                Up_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks + 1,
                    temperal_upsample=t_up_flag,
                    up_flag=i != len(dim_mult) - 1,
                ))
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks (12 output channels: unpatchified back to RGB)
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, 12, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat(
                    [
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device),
                        cache_x,
                    ],
                    dim=2,
                )
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, first_chunk)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count
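
# A minimal sketch (editor's addition, not part of the original file) of the
# feat_cache protocol shared by the encoder and decoder above: one cache slot
# per CausalConv3d, with feat_idx as a single-element list so nested modules
# can advance a shared cursor in place while processing a chunk.
def _feat_cache_sketch():
    enc = Encoder3d()                              # default config
    feat_map = [None] * count_conv3d(enc)          # one slot per causal conv
    chunk = torch.randn(1, 12, 1, 64, 64)          # first chunk, already patchified
    z = enc(chunk, feat_cache=feat_map, feat_idx=[0])   # -> [1, 4, 1, 8, 8]
    return z.shape
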
class AutoencoderKLWan2_2_(nn.Module):

    def __init__(
        self,
        dim=160,
        dec_dim=256,
        z_dim=16,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(
            dim,
            z_dim * 2,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_downsample,
            dropout,
        )
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(
            dec_dim,
            z_dim,
            dim_mult,
            num_res_blocks,
            attn_scales,
            self.temperal_upsample,
            dropout,
        )

    def forward(self, x, scale=[0, 1]):
        # encode() returns mu and log_var concatenated along the channel axis;
        # only the mean channels are decoded.
        mu = self.encode(x, scale).chunk(2, dim=1)[0]
        x_recon = self.decode(mu, scale)
        return x_recon, mu

    def encode(self, x, scale):
        self.clear_cache()
        # x: [b, c, t, h, w]
        scale = [item.to(x.device, x.dtype) for item in scale]
        x = patchify(x, patch_size=2)
        t = x.shape[2]
        # Encode the first frame alone, then the rest in chunks of four frames.
        iter_ = 1 + (t - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        # Normalize the mean with the per-channel latent statistics.
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        x = torch.cat([mu, log_var], dim=1)
        self.clear_cache()
        return x

    def decode(self, z, scale):
        self.clear_cache()
        # z: [b, c, t, h, w]
        scale = [item.to(z.device, z.dtype) for item in scale]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        # Decode one latent frame at a time, reusing the feature cache.
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=True,
                )
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                )
                out = torch.cat([out, out_], 2)
        out = unpatchify(out, patch_size=2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, scale, deterministic=False):
        # encode() returns mu and log_var concatenated along the channel axis.
        mu, log_var = self.encode(imgs, scale).chunk(2, dim=1)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # encoder cache
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
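
# A minimal round-trip sketch (editor's addition, not part of the original
# file): scale is a [mean, 1/std] pair of per-channel latent statistics
# (zeros/ones here), encode() returns mu and log_var concatenated along the
# channel axis, and decode() expects only the z_dim mean channels.
def _wan22_vae_roundtrip_sketch():
    vae = AutoencoderKLWan2_2_().eval()
    mean, inv_std = torch.zeros(16), torch.ones(16)
    video = torch.randn(1, 3, 49, 128, 128)             # T must be 1 + 4k
    with torch.no_grad():
        latent = vae.encode(video, [mean, inv_std])     # [1, 32, 13, 8, 8]
        mu = latent.chunk(2, dim=1)[0]                  # [1, 16, 13, 8, 8]
        recon = vae.decode(mu, [mean, inv_std])         # [1, 3, 49, 128, 128]
    return recon.shape
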
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
    # Default params; weights are loaded later in
    # AutoencoderKLWan3_8.from_pretrained, so pretrained_path is unused here.
    cfg = dict(
        dim=dim,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, True],
        dropout=0.0,
    )
    cfg.update(**kwargs)

    # init model
    model = AutoencoderKLWan2_2_(**cfg)

    return model
class AutoencoderKLWan3_8(ModelMixin, ConfigMixin, FromOriginalModelMixin):

    @register_to_config
    def __init__(
        self,
        latent_channels=48,
        c_dim=160,
        vae_pth=None,
        dim_mult=[1, 2, 4, 4],
        temperal_downsample=[False, True, True],
        temporal_compression_ratio=4,
        spatial_compression_ratio=8
    ):
        super().__init__()
        # Per-channel statistics of the 48-channel latent space.
        mean = torch.tensor(
            [
                -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838,
                0.1557, -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098,
                0.0375, -0.1825, -0.2246, -0.1207, -0.0698, 0.5109, 0.2665,
                -0.2108, -0.2158, 0.2502, -0.2055, -0.0322, 0.1109, 0.1567,
                -0.0729, 0.0899, -0.2799, -0.1230, -0.0313, -0.1649, 0.0117,
                0.0723, -0.2839, -0.2083, -0.0520, 0.3748, 0.0152, 0.1957,
                0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
            ], dtype=torch.float32
        )
        std = torch.tensor(
            [
                0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818,
                0.5013, 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165,
                0.8454, 0.4978, 0.5759, 0.3523, 0.7135, 0.6804, 0.5833,
                1.4146, 0.8986, 0.5659, 0.7069, 0.5338, 0.4889, 0.4917,
                0.4069, 0.4999, 0.6866, 0.4093, 0.5709, 0.6065, 0.6415,
                0.4944, 0.5726, 1.2042, 0.5458, 1.6887, 0.3971, 1.0600,
                0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
            ], dtype=torch.float32
        )
        self.scale = [mean, 1.0 / std]

        # init model
        self.model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=latent_channels,
            dim=c_dim,
            dim_mult=dim_mult,
            temperal_downsample=temperal_downsample,
        ).eval().requires_grad_(False)

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        x = [
            self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
            for u in x
        ]
        x = torch.stack(x)
        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, zs):
        dec = [
            self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
            for u in zs
        ]
        dec = torch.stack(dec)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
        def filter_kwargs(cls, kwargs):
            import inspect
            sig = inspect.signature(cls.__init__)
            valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
            return filtered_kwargs

        model = cls(**filter_kwargs(cls, additional_kwargs))
        if pretrained_model_path.endswith(".safetensors"):
            from safetensors.torch import load_file
            state_dict = load_file(pretrained_model_path)
        else:
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
        # Checkpoint keys belong to the inner VAE; prefix them with "model.".
        state_dict = {"model." + key: value for key, value in state_dict.items()}
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m, u)
        return model
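A minimal loading sketch (editorial addition, not part of the committed file): from_pretrained builds the module from filtered kwargs, then loads the raw Wan checkpoint under the "model." prefix. The checkpoint path below is a placeholder, not a real asset path.

import torch
from videox_fun.models import AutoencoderKLWan3_8

vae = AutoencoderKLWan3_8.from_pretrained(
    "path/to/Wan3.8_VAE.safetensors",   # hypothetical path
    additional_kwargs={"latent_channels": 48, "c_dim": 160},
)
video = torch.randn(1, 3, 49, 256, 256)            # [B, C, T, H, W] in [-1, 1]
latents = vae.encode(video).latent_dist.mode()     # -> [1, 48, 13, 16, 16]
frames = vae.decode(latents).sample                # -> [1, 3, 49, 256, 256]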
videox_fun/models/wan_xlm_roberta.py
ADDED
@@ -0,0 +1,170 @@
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['XLMRoberta', 'xlm_roberta_large']


class SelfAttention(nn.Module):

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # compute attention; attention dropout only applies during training
        p = self.dropout.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, mask, p)
        x = x.permute(0, 2, 1, 3).reshape(b, s, c)

        # output
        x = self.o(x)
        x = self.dropout(x)
        return x


class AttentionBlock(nn.Module):

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
            nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            x = self.norm1(x + self.attn(x, mask))
            x = self.norm2(x + self.ffn(x))
        else:
            x = x + self.attn(self.norm1(x), mask)
            x = x + self.ffn(self.norm2(x))
        return x
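
# A minimal mask sketch (editor's addition, not part of the original file):
# SelfAttention expects an additive float mask broadcastable to
# [B, heads, L, L]; padded positions carry a large negative value so their
# attention weights vanish after the softmax. This mirrors the conversion
# done in XLMRoberta.forward below.
def _additive_mask_sketch(ids, pad_id=1):
    keep = ids.ne(pad_id)                               # [B, L] bool
    mask = torch.where(keep[:, None, None, :], 0.0,
                       torch.finfo(torch.float32).min)  # [B, 1, 1, L]
    return mask
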
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings; position ids advance only on non-padding tokens
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks: convert the padding mask to an additive attention mask
        mask = torch.where(
            mask.view(b, 1, 1, s).gt(0), 0.0,
            torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x
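
# A minimal sketch (editor's addition, not part of the original file) of the
# RoBERTa-style position ids computed in forward() above: positions start at
# pad_id + 1 and advance only on non-padding tokens, while padding stays at
# pad_id, which is also the padding_idx of pos_embedding.
def _position_id_sketch():
    pad_id = 1
    ids = torch.tensor([[0, 250, 37, 2, pad_id, pad_id]])
    mask = ids.ne(pad_id).long()
    pos = pad_id + torch.cumsum(mask, dim=1) * mask
    return pos  # tensor([[2, 3, 4, 5, 1, 1]])
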
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.
    """
    # params (pretrained and return_tokenizer are accepted for API
    # compatibility but unused here; weights are loaded by the caller)
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init a model on device
    with torch.device(device):
        model = XLMRoberta(**cfg)
    return model
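A minimal usage sketch (editorial addition, not part of the committed file): constructing the text encoder and running a padded batch through it. The ids below are made up; real inputs would come from an XLM-RoBERTa tokenizer with pad_id=1.

import torch
from videox_fun.models.wan_xlm_roberta import xlm_roberta_large

model = xlm_roberta_large(device='cpu').eval()
ids = torch.tensor([[0, 48146, 9020, 2, 1, 1]], dtype=torch.long)  # hypothetical token ids
with torch.no_grad():
    features = model(ids)                          # -> [1, 6, 1024]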