XiangpengYang commited on
Commit
6f25f9f
·
1 Parent(s): 3daab90
.gitignore CHANGED
@@ -1,4 +1,3 @@
1
  samples/
2
- models/
3
  __pycache__/
4
  *.pyc
 
1
  samples/
 
2
  __pycache__/
3
  *.pyc
videox_fun/models/__init__.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+
3
+ from diffusers import AutoencoderKL
4
+ from transformers import (AutoTokenizer, CLIPImageProcessor, CLIPTextModel,
5
+ CLIPTokenizer, CLIPVisionModelWithProjection,
6
+ T5EncoderModel, T5Tokenizer, T5TokenizerFast)
7
+
8
+ try:
9
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
10
+ except:
11
+ Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
12
+ print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade your transformers package to the latest version.")
13
+
14
+ from .cogvideox_transformer3d import CogVideoXTransformer3DModel
15
+ from .cogvideox_vae import AutoencoderKLCogVideoX
16
+ from .flux_transformer2d import FluxTransformer2DModel
17
+ from .qwenimage_transformer2d import QwenImageTransformer2DModel
18
+ from .qwenimage_vae import AutoencoderKLQwenImage
19
+ # from .wan_audio_encoder import WanAudioEncoder
20
+ from .wan_image_encoder import CLIPModel
21
+ from .wan_text_encoder import WanT5EncoderModel
22
+ from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
23
+ WanSelfAttention, WanTransformer3DModel)
24
+ # from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
25
+ from .wan_transformer3d_vace import VaceWanTransformer3DModel
26
+ from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
27
+ from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8
28
+
29
+ # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
30
+ if importlib.util.find_spec("paifuser") is not None:
31
+ # --------------------------------------------------------------- #
32
+ # The simple_wrapper is used to solve the problem
33
+ # about conflicts between cython and torch.compile
34
+ # --------------------------------------------------------------- #
35
+ def simple_wrapper(func):
36
+ def inner(*args, **kwargs):
37
+ return func(*args, **kwargs)
38
+ return inner
39
+
40
+ # --------------------------------------------------------------- #
41
+ # VAE Parallel Kernel
42
+ # --------------------------------------------------------------- #
43
+ from ..dist import parallel_magvit_vae
44
+ AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
45
+ AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))
46
+
47
+ # --------------------------------------------------------------- #
48
+ # Sparse Attention
49
+ # --------------------------------------------------------------- #
50
+ import torch
51
+ from paifuser.ops import wan_sparse_attention_wrapper
52
+
53
+ WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
54
+ print("Import Sparse Attention")
55
+
56
+ WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)
57
+
58
+ # --------------------------------------------------------------- #
59
+ # CFG Skip Turbo
60
+ # --------------------------------------------------------------- #
61
+ import os
62
+
63
+ if importlib.util.find_spec("paifuser.accelerator") is not None:
64
+ from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
65
+ enable_cfg_skip, share_cfg_skip)
66
+ else:
67
+ from paifuser import (cfg_skip_turbo, disable_cfg_skip,
68
+ enable_cfg_skip, share_cfg_skip)
69
+
70
+ WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
71
+ WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
72
+ WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)
73
+ print("Import CFG Skip Turbo")
74
+
75
+ # --------------------------------------------------------------- #
76
+ # RMS Norm Kernel
77
+ # --------------------------------------------------------------- #
78
+ from paifuser.ops import rms_norm_forward
79
+ WanRMSNorm.forward = rms_norm_forward
80
+ print("Import PAI RMS Fuse")
81
+
82
+ # --------------------------------------------------------------- #
83
+ # Fast Rope Kernel
84
+ # --------------------------------------------------------------- #
85
+ import types
86
+
87
+ import torch
88
+ from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
89
+ rope_apply_real_qk)
90
+
91
+ from . import wan_transformer3d
92
+
93
+ def deepcopy_function(f):
94
+ return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__)
95
+
96
+ local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)
97
+
98
+ if ENABLE_KERNEL:
99
+ def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
100
+ if torch.is_grad_enabled():
101
+ return local_rope_apply_qk(q, k, grid_sizes, freqs)
102
+ else:
103
+ return fast_rope_apply_qk(q, k, grid_sizes, freqs)
104
+ else:
105
+ def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
106
+ return rope_apply_real_qk(q, k, grid_sizes, freqs)
107
+
108
+ wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
109
+ rope_apply_qk = adaptive_fast_rope_apply_qk
110
+ print("Import PAI Fast rope")
videox_fun/models/attention_utils.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import warnings
5
+
6
+ try:
7
+ import flash_attn_interface
8
+ FLASH_ATTN_3_AVAILABLE = True
9
+ except ModuleNotFoundError:
10
+ FLASH_ATTN_3_AVAILABLE = False
11
+
12
+ try:
13
+ import flash_attn
14
+ FLASH_ATTN_2_AVAILABLE = True
15
+ except ModuleNotFoundError:
16
+ FLASH_ATTN_2_AVAILABLE = False
17
+
18
+ try:
19
+ major, minor = torch.cuda.get_device_capability(0)
20
+ if f"{major}.{minor}" == "8.0":
21
+ from sageattention_sm80 import sageattn
22
+ SAGE_ATTENTION_AVAILABLE = True
23
+ elif f"{major}.{minor}" == "8.6":
24
+ from sageattention_sm86 import sageattn
25
+ SAGE_ATTENTION_AVAILABLE = True
26
+ elif f"{major}.{minor}" == "8.9":
27
+ from sageattention_sm89 import sageattn
28
+ SAGE_ATTENTION_AVAILABLE = True
29
+ elif f"{major}.{minor}" == "9.0":
30
+ from sageattention_sm90 import sageattn
31
+ SAGE_ATTENTION_AVAILABLE = True
32
+ elif major>9:
33
+ from sageattention_sm120 import sageattn
34
+ SAGE_ATTENTION_AVAILABLE = True
35
+ except:
36
+ try:
37
+ from sageattention import sageattn
38
+ SAGE_ATTENTION_AVAILABLE = True
39
+ except:
40
+ sageattn = None
41
+ SAGE_ATTENTION_AVAILABLE = False
42
+
43
+ def flash_attention(
44
+ q,
45
+ k,
46
+ v,
47
+ q_lens=None,
48
+ k_lens=None,
49
+ dropout_p=0.,
50
+ softmax_scale=None,
51
+ q_scale=None,
52
+ causal=False,
53
+ window_size=(-1, -1),
54
+ deterministic=False,
55
+ dtype=torch.bfloat16,
56
+ version=None,
57
+ ):
58
+ """
59
+ q: [B, Lq, Nq, C1].
60
+ k: [B, Lk, Nk, C1].
61
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
62
+ q_lens: [B].
63
+ k_lens: [B].
64
+ dropout_p: float. Dropout probability.
65
+ softmax_scale: float. The scaling of QK^T before applying softmax.
66
+ causal: bool. Whether to apply causal attention mask.
67
+ window_size: (left right). If not (-1, -1), apply sliding window local attention.
68
+ deterministic: bool. If True, slightly slower and uses more memory.
69
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
70
+ """
71
+ half_dtypes = (torch.float16, torch.bfloat16)
72
+ assert dtype in half_dtypes
73
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
74
+
75
+ # params
76
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
77
+
78
+ def half(x):
79
+ return x if x.dtype in half_dtypes else x.to(dtype)
80
+
81
+ # preprocess query
82
+ if q_lens is None:
83
+ q = half(q.flatten(0, 1))
84
+ q_lens = torch.tensor(
85
+ [lq] * b, dtype=torch.int32).to(
86
+ device=q.device, non_blocking=True)
87
+ else:
88
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
89
+
90
+ # preprocess key, value
91
+ if k_lens is None:
92
+ k = half(k.flatten(0, 1))
93
+ v = half(v.flatten(0, 1))
94
+ k_lens = torch.tensor(
95
+ [lk] * b, dtype=torch.int32).to(
96
+ device=k.device, non_blocking=True)
97
+ else:
98
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
99
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
100
+
101
+ q = q.to(v.dtype)
102
+ k = k.to(v.dtype)
103
+
104
+ if q_scale is not None:
105
+ q = q * q_scale
106
+
107
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
108
+ warnings.warn(
109
+ 'Flash attention 3 is not available, use flash attention 2 instead.'
110
+ )
111
+
112
+ # apply attention
113
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
114
+ # Note: dropout_p, window_size are not supported in FA3 now.
115
+ x = flash_attn_interface.flash_attn_varlen_func(
116
+ q=q,
117
+ k=k,
118
+ v=v,
119
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
120
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
121
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
122
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
123
+ seqused_q=None,
124
+ seqused_k=None,
125
+ max_seqlen_q=lq,
126
+ max_seqlen_k=lk,
127
+ softmax_scale=softmax_scale,
128
+ causal=causal,
129
+ deterministic=deterministic).unflatten(0, (b, lq))
130
+ else:
131
+ assert FLASH_ATTN_2_AVAILABLE
132
+ x = flash_attn.flash_attn_varlen_func(
133
+ q=q,
134
+ k=k,
135
+ v=v,
136
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
137
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
138
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
139
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
140
+ max_seqlen_q=lq,
141
+ max_seqlen_k=lk,
142
+ dropout_p=dropout_p,
143
+ softmax_scale=softmax_scale,
144
+ causal=causal,
145
+ window_size=window_size,
146
+ deterministic=deterministic).unflatten(0, (b, lq))
147
+
148
+ # output
149
+ return x.type(out_dtype)
150
+
151
+
152
+ def attention(
153
+ q,
154
+ k,
155
+ v,
156
+ q_lens=None,
157
+ k_lens=None,
158
+ dropout_p=0.,
159
+ softmax_scale=None,
160
+ q_scale=None,
161
+ causal=False,
162
+ window_size=(-1, -1),
163
+ deterministic=False,
164
+ dtype=torch.bfloat16,
165
+ fa_version=None,
166
+ attention_type=None,
167
+ attn_mask=None,
168
+ ):
169
+ attention_type = os.environ.get("VIDEOX_ATTENTION_TYPE", "FLASH_ATTENTION") if attention_type is None else attention_type
170
+ if torch.is_grad_enabled() and attention_type == "SAGE_ATTENTION":
171
+ attention_type = "FLASH_ATTENTION"
172
+
173
+ if attention_type == "SAGE_ATTENTION" and SAGE_ATTENTION_AVAILABLE:
174
+ if q_lens is not None or k_lens is not None:
175
+ warnings.warn(
176
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
177
+ )
178
+
179
+ out = sageattn(
180
+ q, k, v, attn_mask=attn_mask, tensor_layout="NHD", is_causal=causal, dropout_p=dropout_p)
181
+
182
+ elif attention_type == "FLASH_ATTENTION" and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
183
+ return flash_attention(
184
+ q=q,
185
+ k=k,
186
+ v=v,
187
+ q_lens=q_lens,
188
+ k_lens=k_lens,
189
+ dropout_p=dropout_p,
190
+ softmax_scale=softmax_scale,
191
+ q_scale=q_scale,
192
+ causal=causal,
193
+ window_size=window_size,
194
+ deterministic=deterministic,
195
+ dtype=dtype,
196
+ version=fa_version,
197
+ )
198
+ else:
199
+ if q_lens is not None or k_lens is not None:
200
+ warnings.warn(
201
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
202
+ )
203
+ q = q.transpose(1, 2)
204
+ k = k.transpose(1, 2)
205
+ v = v.transpose(1, 2)
206
+
207
+ out = torch.nn.functional.scaled_dot_product_attention(
208
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
209
+
210
+ out = out.transpose(1, 2).contiguous()
211
+ return out
videox_fun/models/cache_utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ def get_teacache_coefficients(model_name):
5
+ if "wan2.1-t2v-1.3b" in model_name.lower() or "wan2.1-fun-1.3b" in model_name.lower() \
6
+ or "wan2.1-fun-v1.1-1.3b" in model_name.lower() or "wan2.1-vace-1.3b" in model_name.lower():
7
+ return [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
8
+ elif "wan2.1-t2v-14b" in model_name.lower():
9
+ return [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
10
+ elif "wan2.1-i2v-14b-480p" in model_name.lower():
11
+ return [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
12
+ elif "wan2.1-i2v-14b-720p" in model_name.lower() or "wan2.1-fun-14b" in model_name.lower() or "wan2.2-fun" in model_name.lower() \
13
+ or "wan2.2-i2v-a14b" in model_name.lower() or "wan2.2-t2v-a14b" in model_name.lower() or "wan2.2-ti2v-5b" in model_name.lower() \
14
+ or "wan2.2-s2v" in model_name.lower() or "wan2.1-vace-14b" in model_name.lower() or "wan2.2-vace-fun" in model_name.lower():
15
+ return [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
16
+ else:
17
+ print(f"The model {model_name} is not supported by TeaCache.")
18
+ return None
19
+
20
+
21
+ class TeaCache():
22
+ """
23
+ Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages
24
+ the fluctuating differences among model outputs across timesteps, thereby accelerating the inference.
25
+ Please refer to:
26
+ 1. https://github.com/ali-vilab/TeaCache.
27
+ 2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
28
+ """
29
+ def __init__(
30
+ self,
31
+ coefficients: list[float],
32
+ num_steps: int,
33
+ rel_l1_thresh: float = 0.0,
34
+ num_skip_start_steps: int = 0,
35
+ offload: bool = True,
36
+ ):
37
+ if num_steps < 1:
38
+ raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
39
+ if rel_l1_thresh < 0:
40
+ raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
41
+ if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
42
+ raise ValueError(
43
+ "`num_skip_start_steps` must be great than or equal to 0 and "
44
+ f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
45
+ )
46
+ self.coefficients = coefficients
47
+ self.num_steps = num_steps
48
+ self.rel_l1_thresh = rel_l1_thresh
49
+ self.num_skip_start_steps = num_skip_start_steps
50
+ self.offload = offload
51
+ self.rescale_func = np.poly1d(self.coefficients)
52
+
53
+ self.cnt = 0
54
+ self.should_calc = True
55
+ self.accumulated_rel_l1_distance = 0
56
+ self.previous_modulated_input = None
57
+ # Some pipelines concatenate the unconditional and text guide in forward.
58
+ self.previous_residual = None
59
+ # Some pipelines perform forward propagation separately on the unconditional and text guide.
60
+ self.previous_residual_cond = None
61
+ self.previous_residual_uncond = None
62
+
63
+ @staticmethod
64
+ def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor:
65
+ rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean()
66
+
67
+ return rel_l1_distance.cpu().item()
68
+
69
+ def reset(self):
70
+ self.cnt = 0
71
+ self.should_calc = True
72
+ self.accumulated_rel_l1_distance = 0
73
+ self.previous_modulated_input = None
74
+ self.previous_residual = None
75
+ self.previous_residual_cond = None
76
+ self.previous_residual_uncond = None
videox_fun/models/cogvideox_transformer3d.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import json
18
+ import os
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.attention import Attention, FeedForward
25
+ from diffusers.models.attention_processor import (
26
+ AttentionProcessor, CogVideoXAttnProcessor2_0,
27
+ FusedCogVideoXAttnProcessor2_0)
28
+ from diffusers.models.embeddings import (CogVideoXPatchEmbed,
29
+ TimestepEmbedding, Timesteps,
30
+ get_3d_sincos_pos_embed)
31
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
34
+ from diffusers.utils import is_torch_version, logging
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from torch import nn
37
+
38
+ from ..dist import (get_sequence_parallel_rank,
39
+ get_sequence_parallel_world_size, get_sp_group,
40
+ xFuserLongContextAttention)
41
+ from ..dist.cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+
46
+ class CogVideoXPatchEmbed(nn.Module):
47
+ def __init__(
48
+ self,
49
+ patch_size: int = 2,
50
+ patch_size_t: Optional[int] = None,
51
+ in_channels: int = 16,
52
+ embed_dim: int = 1920,
53
+ text_embed_dim: int = 4096,
54
+ bias: bool = True,
55
+ sample_width: int = 90,
56
+ sample_height: int = 60,
57
+ sample_frames: int = 49,
58
+ temporal_compression_ratio: int = 4,
59
+ max_text_seq_length: int = 226,
60
+ spatial_interpolation_scale: float = 1.875,
61
+ temporal_interpolation_scale: float = 1.0,
62
+ use_positional_embeddings: bool = True,
63
+ use_learned_positional_embeddings: bool = True,
64
+ ) -> None:
65
+ super().__init__()
66
+
67
+ post_patch_height = sample_height // patch_size
68
+ post_patch_width = sample_width // patch_size
69
+ post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
70
+ self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
71
+ self.post_patch_height = post_patch_height
72
+ self.post_patch_width = post_patch_width
73
+ self.post_time_compression_frames = post_time_compression_frames
74
+ self.patch_size = patch_size
75
+ self.patch_size_t = patch_size_t
76
+ self.embed_dim = embed_dim
77
+ self.sample_height = sample_height
78
+ self.sample_width = sample_width
79
+ self.sample_frames = sample_frames
80
+ self.temporal_compression_ratio = temporal_compression_ratio
81
+ self.max_text_seq_length = max_text_seq_length
82
+ self.spatial_interpolation_scale = spatial_interpolation_scale
83
+ self.temporal_interpolation_scale = temporal_interpolation_scale
84
+ self.use_positional_embeddings = use_positional_embeddings
85
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
86
+
87
+ if patch_size_t is None:
88
+ # CogVideoX 1.0 checkpoints
89
+ self.proj = nn.Conv2d(
90
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
91
+ )
92
+ else:
93
+ # CogVideoX 1.5 checkpoints
94
+ self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
95
+
96
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
97
+
98
+ if use_positional_embeddings or use_learned_positional_embeddings:
99
+ persistent = use_learned_positional_embeddings
100
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
101
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
102
+
103
+ def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
104
+ post_patch_height = sample_height // self.patch_size
105
+ post_patch_width = sample_width // self.patch_size
106
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
107
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
108
+
109
+ pos_embedding = get_3d_sincos_pos_embed(
110
+ self.embed_dim,
111
+ (post_patch_width, post_patch_height),
112
+ post_time_compression_frames,
113
+ self.spatial_interpolation_scale,
114
+ self.temporal_interpolation_scale,
115
+ )
116
+ pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
117
+ joint_pos_embedding = torch.zeros(
118
+ 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
119
+ )
120
+ joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
121
+
122
+ return joint_pos_embedding
123
+
124
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
125
+ r"""
126
+ Args:
127
+ text_embeds (`torch.Tensor`):
128
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
129
+ image_embeds (`torch.Tensor`):
130
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
131
+ """
132
+ text_embeds = self.text_proj(text_embeds)
133
+
134
+ text_batch_size, text_seq_length, text_channels = text_embeds.shape
135
+ batch_size, num_frames, channels, height, width = image_embeds.shape
136
+
137
+ if self.patch_size_t is None:
138
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
139
+ image_embeds = self.proj(image_embeds)
140
+ image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
141
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
142
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
143
+ else:
144
+ p = self.patch_size
145
+ p_t = self.patch_size_t
146
+
147
+ image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
148
+ # b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c
149
+ image_embeds = image_embeds.reshape(
150
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
151
+ )
152
+ # b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2
153
+ image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
154
+ image_embeds = self.proj(image_embeds)
155
+
156
+ embeds = torch.cat(
157
+ [text_embeds, image_embeds], dim=1
158
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
159
+
160
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
161
+ seq_length = height * width * num_frames // (self.patch_size**2)
162
+ # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
163
+ pos_embeds = self.pos_embedding
164
+ emb_size = embeds.size()[-1]
165
+ pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
166
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
167
+ pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
168
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
169
+ pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
170
+ pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
171
+ embeds = embeds + pos_embeds
172
+
173
+ return embeds
174
+
175
+ @maybe_allow_in_graph
176
+ class CogVideoXBlock(nn.Module):
177
+ r"""
178
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
179
+
180
+ Parameters:
181
+ dim (`int`):
182
+ The number of channels in the input and output.
183
+ num_attention_heads (`int`):
184
+ The number of heads to use for multi-head attention.
185
+ attention_head_dim (`int`):
186
+ The number of channels in each head.
187
+ time_embed_dim (`int`):
188
+ The number of channels in timestep embedding.
189
+ dropout (`float`, defaults to `0.0`):
190
+ The dropout probability to use.
191
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
192
+ Activation function to be used in feed-forward.
193
+ attention_bias (`bool`, defaults to `False`):
194
+ Whether or not to use bias in attention projection layers.
195
+ qk_norm (`bool`, defaults to `True`):
196
+ Whether or not to use normalization after query and key projections in Attention.
197
+ norm_elementwise_affine (`bool`, defaults to `True`):
198
+ Whether to use learnable elementwise affine parameters for normalization.
199
+ norm_eps (`float`, defaults to `1e-5`):
200
+ Epsilon value for normalization layers.
201
+ final_dropout (`bool` defaults to `False`):
202
+ Whether to apply a final dropout after the last feed-forward layer.
203
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
204
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
205
+ ff_bias (`bool`, defaults to `True`):
206
+ Whether or not to use bias in Feed-forward layer.
207
+ attention_out_bias (`bool`, defaults to `True`):
208
+ Whether or not to use bias in Attention output projection layer.
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ dim: int,
214
+ num_attention_heads: int,
215
+ attention_head_dim: int,
216
+ time_embed_dim: int,
217
+ dropout: float = 0.0,
218
+ activation_fn: str = "gelu-approximate",
219
+ attention_bias: bool = False,
220
+ qk_norm: bool = True,
221
+ norm_elementwise_affine: bool = True,
222
+ norm_eps: float = 1e-5,
223
+ final_dropout: bool = True,
224
+ ff_inner_dim: Optional[int] = None,
225
+ ff_bias: bool = True,
226
+ attention_out_bias: bool = True,
227
+ ):
228
+ super().__init__()
229
+
230
+ # 1. Self Attention
231
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
232
+
233
+ self.attn1 = Attention(
234
+ query_dim=dim,
235
+ dim_head=attention_head_dim,
236
+ heads=num_attention_heads,
237
+ qk_norm="layer_norm" if qk_norm else None,
238
+ eps=1e-6,
239
+ bias=attention_bias,
240
+ out_bias=attention_out_bias,
241
+ processor=CogVideoXAttnProcessor2_0(),
242
+ )
243
+
244
+ # 2. Feed Forward
245
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
246
+
247
+ self.ff = FeedForward(
248
+ dim,
249
+ dropout=dropout,
250
+ activation_fn=activation_fn,
251
+ final_dropout=final_dropout,
252
+ inner_dim=ff_inner_dim,
253
+ bias=ff_bias,
254
+ )
255
+
256
+ def forward(
257
+ self,
258
+ hidden_states: torch.Tensor,
259
+ encoder_hidden_states: torch.Tensor,
260
+ temb: torch.Tensor,
261
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
262
+ ) -> torch.Tensor:
263
+ text_seq_length = encoder_hidden_states.size(1)
264
+
265
+ # norm & modulate
266
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
267
+ hidden_states, encoder_hidden_states, temb
268
+ )
269
+
270
+ # attention
271
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
272
+ hidden_states=norm_hidden_states,
273
+ encoder_hidden_states=norm_encoder_hidden_states,
274
+ image_rotary_emb=image_rotary_emb,
275
+ )
276
+
277
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
278
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
279
+
280
+ # norm & modulate
281
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
282
+ hidden_states, encoder_hidden_states, temb
283
+ )
284
+
285
+ # feed-forward
286
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
287
+ ff_output = self.ff(norm_hidden_states)
288
+
289
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
290
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
291
+
292
+ return hidden_states, encoder_hidden_states
293
+
294
+
295
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
296
+ """
297
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
298
+
299
+ Parameters:
300
+ num_attention_heads (`int`, defaults to `30`):
301
+ The number of heads to use for multi-head attention.
302
+ attention_head_dim (`int`, defaults to `64`):
303
+ The number of channels in each head.
304
+ in_channels (`int`, defaults to `16`):
305
+ The number of channels in the input.
306
+ out_channels (`int`, *optional*, defaults to `16`):
307
+ The number of channels in the output.
308
+ flip_sin_to_cos (`bool`, defaults to `True`):
309
+ Whether to flip the sin to cos in the time embedding.
310
+ time_embed_dim (`int`, defaults to `512`):
311
+ Output dimension of timestep embeddings.
312
+ text_embed_dim (`int`, defaults to `4096`):
313
+ Input dimension of text embeddings from the text encoder.
314
+ num_layers (`int`, defaults to `30`):
315
+ The number of layers of Transformer blocks to use.
316
+ dropout (`float`, defaults to `0.0`):
317
+ The dropout probability to use.
318
+ attention_bias (`bool`, defaults to `True`):
319
+ Whether or not to use bias in the attention projection layers.
320
+ sample_width (`int`, defaults to `90`):
321
+ The width of the input latents.
322
+ sample_height (`int`, defaults to `60`):
323
+ The height of the input latents.
324
+ sample_frames (`int`, defaults to `49`):
325
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
326
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
327
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
328
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
329
+ patch_size (`int`, defaults to `2`):
330
+ The size of the patches to use in the patch embedding layer.
331
+ temporal_compression_ratio (`int`, defaults to `4`):
332
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
333
+ max_text_seq_length (`int`, defaults to `226`):
334
+ The maximum sequence length of the input text embeddings.
335
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
336
+ Activation function to use in feed-forward.
337
+ timestep_activation_fn (`str`, defaults to `"silu"`):
338
+ Activation function to use when generating the timestep embeddings.
339
+ norm_elementwise_affine (`bool`, defaults to `True`):
340
+ Whether or not to use elementwise affine in normalization layers.
341
+ norm_eps (`float`, defaults to `1e-5`):
342
+ The epsilon value to use in normalization layers.
343
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
344
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
345
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
346
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
347
+ """
348
+
349
+ _supports_gradient_checkpointing = True
350
+
351
+ @register_to_config
352
+ def __init__(
353
+ self,
354
+ num_attention_heads: int = 30,
355
+ attention_head_dim: int = 64,
356
+ in_channels: int = 16,
357
+ out_channels: Optional[int] = 16,
358
+ flip_sin_to_cos: bool = True,
359
+ freq_shift: int = 0,
360
+ time_embed_dim: int = 512,
361
+ text_embed_dim: int = 4096,
362
+ num_layers: int = 30,
363
+ dropout: float = 0.0,
364
+ attention_bias: bool = True,
365
+ sample_width: int = 90,
366
+ sample_height: int = 60,
367
+ sample_frames: int = 49,
368
+ patch_size: int = 2,
369
+ patch_size_t: Optional[int] = None,
370
+ temporal_compression_ratio: int = 4,
371
+ max_text_seq_length: int = 226,
372
+ activation_fn: str = "gelu-approximate",
373
+ timestep_activation_fn: str = "silu",
374
+ norm_elementwise_affine: bool = True,
375
+ norm_eps: float = 1e-5,
376
+ spatial_interpolation_scale: float = 1.875,
377
+ temporal_interpolation_scale: float = 1.0,
378
+ use_rotary_positional_embeddings: bool = False,
379
+ use_learned_positional_embeddings: bool = False,
380
+ patch_bias: bool = True,
381
+ add_noise_in_inpaint_model: bool = False,
382
+ ):
383
+ super().__init__()
384
+ inner_dim = num_attention_heads * attention_head_dim
385
+ self.patch_size_t = patch_size_t
386
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
387
+ raise ValueError(
388
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
389
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
390
+ "issue at https://github.com/huggingface/diffusers/issues."
391
+ )
392
+
393
+ # 1. Patch embedding
394
+ self.patch_embed = CogVideoXPatchEmbed(
395
+ patch_size=patch_size,
396
+ patch_size_t=patch_size_t,
397
+ in_channels=in_channels,
398
+ embed_dim=inner_dim,
399
+ text_embed_dim=text_embed_dim,
400
+ bias=patch_bias,
401
+ sample_width=sample_width,
402
+ sample_height=sample_height,
403
+ sample_frames=sample_frames,
404
+ temporal_compression_ratio=temporal_compression_ratio,
405
+ max_text_seq_length=max_text_seq_length,
406
+ spatial_interpolation_scale=spatial_interpolation_scale,
407
+ temporal_interpolation_scale=temporal_interpolation_scale,
408
+ use_positional_embeddings=not use_rotary_positional_embeddings,
409
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
410
+ )
411
+ self.embedding_dropout = nn.Dropout(dropout)
412
+
413
+ # 2. Time embeddings
414
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
415
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
416
+
417
+ # 3. Define spatio-temporal transformers blocks
418
+ self.transformer_blocks = nn.ModuleList(
419
+ [
420
+ CogVideoXBlock(
421
+ dim=inner_dim,
422
+ num_attention_heads=num_attention_heads,
423
+ attention_head_dim=attention_head_dim,
424
+ time_embed_dim=time_embed_dim,
425
+ dropout=dropout,
426
+ activation_fn=activation_fn,
427
+ attention_bias=attention_bias,
428
+ norm_elementwise_affine=norm_elementwise_affine,
429
+ norm_eps=norm_eps,
430
+ )
431
+ for _ in range(num_layers)
432
+ ]
433
+ )
434
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
435
+
436
+ # 4. Output blocks
437
+ self.norm_out = AdaLayerNorm(
438
+ embedding_dim=time_embed_dim,
439
+ output_dim=2 * inner_dim,
440
+ norm_elementwise_affine=norm_elementwise_affine,
441
+ norm_eps=norm_eps,
442
+ chunk_dim=1,
443
+ )
444
+
445
+ if patch_size_t is None:
446
+ # For CogVideox 1.0
447
+ output_dim = patch_size * patch_size * out_channels
448
+ else:
449
+ # For CogVideoX 1.5
450
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
451
+
452
+ self.proj_out = nn.Linear(inner_dim, output_dim)
453
+
454
+ self.gradient_checkpointing = False
455
+ self.sp_world_size = 1
456
+ self.sp_world_rank = 0
457
+
458
+ def _set_gradient_checkpointing(self, module, value=False):
459
+ self.gradient_checkpointing = value
460
+
461
+ def enable_multi_gpus_inference(self,):
462
+ self.sp_world_size = get_sequence_parallel_world_size()
463
+ self.sp_world_rank = get_sequence_parallel_rank()
464
+ self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
465
+
466
+ @property
467
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
468
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
469
+ r"""
470
+ Returns:
471
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
472
+ indexed by its weight name.
473
+ """
474
+ # set recursively
475
+ processors = {}
476
+
477
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
478
+ if hasattr(module, "get_processor"):
479
+ processors[f"{name}.processor"] = module.get_processor()
480
+
481
+ for sub_name, child in module.named_children():
482
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
483
+
484
+ return processors
485
+
486
+ for name, module in self.named_children():
487
+ fn_recursive_add_processors(name, module, processors)
488
+
489
+ return processors
490
+
491
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
492
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
493
+ r"""
494
+ Sets the attention processor to use to compute attention.
495
+
496
+ Parameters:
497
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
498
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
499
+ for **all** `Attention` layers.
500
+
501
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
502
+ processor. This is strongly recommended when setting trainable attention processors.
503
+
504
+ """
505
+ count = len(self.attn_processors.keys())
506
+
507
+ if isinstance(processor, dict) and len(processor) != count:
508
+ raise ValueError(
509
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
510
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
511
+ )
512
+
513
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
514
+ if hasattr(module, "set_processor"):
515
+ if not isinstance(processor, dict):
516
+ module.set_processor(processor)
517
+ else:
518
+ module.set_processor(processor.pop(f"{name}.processor"))
519
+
520
+ for sub_name, child in module.named_children():
521
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
522
+
523
+ for name, module in self.named_children():
524
+ fn_recursive_attn_processor(name, module, processor)
525
+
526
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
527
+ def fuse_qkv_projections(self):
528
+ """
529
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
530
+ are fused. For cross-attention modules, key and value projection matrices are fused.
531
+
532
+ <Tip warning={true}>
533
+
534
+ This API is 🧪 experimental.
535
+
536
+ </Tip>
537
+ """
538
+ self.original_attn_processors = None
539
+
540
+ for _, attn_processor in self.attn_processors.items():
541
+ if "Added" in str(attn_processor.__class__.__name__):
542
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
543
+
544
+ self.original_attn_processors = self.attn_processors
545
+
546
+ for module in self.modules():
547
+ if isinstance(module, Attention):
548
+ module.fuse_projections(fuse=True)
549
+
550
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
551
+
552
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
553
+ def unfuse_qkv_projections(self):
554
+ """Disables the fused QKV projection if enabled.
555
+
556
+ <Tip warning={true}>
557
+
558
+ This API is 🧪 experimental.
559
+
560
+ </Tip>
561
+
562
+ """
563
+ if self.original_attn_processors is not None:
564
+ self.set_attn_processor(self.original_attn_processors)
565
+
566
+ def forward(
567
+ self,
568
+ hidden_states: torch.Tensor,
569
+ encoder_hidden_states: torch.Tensor,
570
+ timestep: Union[int, float, torch.LongTensor],
571
+ timestep_cond: Optional[torch.Tensor] = None,
572
+ inpaint_latents: Optional[torch.Tensor] = None,
573
+ control_latents: Optional[torch.Tensor] = None,
574
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
575
+ return_dict: bool = True,
576
+ ):
577
+ batch_size, num_frames, channels, height, width = hidden_states.shape
578
+ if num_frames == 1 and self.patch_size_t is not None:
579
+ hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
580
+ if inpaint_latents is not None:
581
+ inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
582
+ if control_latents is not None:
583
+ control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
584
+ local_num_frames = num_frames + 1
585
+ else:
586
+ local_num_frames = num_frames
587
+
588
+ # 1. Time embedding
589
+ timesteps = timestep
590
+ t_emb = self.time_proj(timesteps)
591
+
592
+ # timesteps does not contain any weights and will always return f32 tensors
593
+ # but time_embedding might actually be running in fp16. so we need to cast here.
594
+ # there might be better ways to encapsulate this.
595
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
596
+ emb = self.time_embedding(t_emb, timestep_cond)
597
+
598
+ # 2. Patch embedding
599
+ if inpaint_latents is not None:
600
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
601
+ if control_latents is not None:
602
+ hidden_states = torch.concat([hidden_states, control_latents], 2)
603
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
604
+ hidden_states = self.embedding_dropout(hidden_states)
605
+
606
+ text_seq_length = encoder_hidden_states.shape[1]
607
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
608
+ hidden_states = hidden_states[:, text_seq_length:]
609
+
610
+ # Context Parallel
611
+ if self.sp_world_size > 1:
612
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
613
+ if image_rotary_emb is not None:
614
+ image_rotary_emb = (
615
+ torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
616
+ torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
617
+ )
618
+
619
+ # 3. Transformer blocks
620
+ for i, block in enumerate(self.transformer_blocks):
621
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
622
+
623
+ def create_custom_forward(module):
624
+ def custom_forward(*inputs):
625
+ return module(*inputs)
626
+
627
+ return custom_forward
628
+
629
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
630
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
631
+ create_custom_forward(block),
632
+ hidden_states,
633
+ encoder_hidden_states,
634
+ emb,
635
+ image_rotary_emb,
636
+ **ckpt_kwargs,
637
+ )
638
+ else:
639
+ hidden_states, encoder_hidden_states = block(
640
+ hidden_states=hidden_states,
641
+ encoder_hidden_states=encoder_hidden_states,
642
+ temb=emb,
643
+ image_rotary_emb=image_rotary_emb,
644
+ )
645
+
646
+ if not self.config.use_rotary_positional_embeddings:
647
+ # CogVideoX-2B
648
+ hidden_states = self.norm_final(hidden_states)
649
+ else:
650
+ # CogVideoX-5B
651
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
652
+ hidden_states = self.norm_final(hidden_states)
653
+ hidden_states = hidden_states[:, text_seq_length:]
654
+
655
+ # 4. Final block
656
+ hidden_states = self.norm_out(hidden_states, temb=emb)
657
+ hidden_states = self.proj_out(hidden_states)
658
+
659
+ if self.sp_world_size > 1:
660
+ hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
661
+
662
+ # 5. Unpatchify
663
+ p = self.config.patch_size
664
+ p_t = self.config.patch_size_t
665
+
666
+ if p_t is None:
667
+ output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
668
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
669
+ else:
670
+ output = hidden_states.reshape(
671
+ batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
672
+ )
673
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
674
+
675
+ if num_frames == 1:
676
+ output = output[:, :num_frames, :]
677
+
678
+ if not return_dict:
679
+ return (output,)
680
+ return Transformer2DModelOutput(sample=output)
681
+
682
+ @classmethod
683
+ def from_pretrained(
684
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
685
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
686
+ ):
687
+ if subfolder is not None:
688
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
689
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
690
+
691
+ config_file = os.path.join(pretrained_model_path, 'config.json')
692
+ if not os.path.isfile(config_file):
693
+ raise RuntimeError(f"{config_file} does not exist")
694
+ with open(config_file, "r") as f:
695
+ config = json.load(f)
696
+
697
+ from diffusers.utils import WEIGHTS_NAME
698
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
699
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
700
+
701
+ if "dict_mapping" in transformer_additional_kwargs.keys():
702
+ for key in transformer_additional_kwargs["dict_mapping"]:
703
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
704
+
705
+ if low_cpu_mem_usage:
706
+ try:
707
+ import re
708
+
709
+ from diffusers import __version__ as diffusers_version
710
+ if diffusers_version >= "0.33.0":
711
+ from diffusers.models.model_loading_utils import \
712
+ load_model_dict_into_meta
713
+ else:
714
+ from diffusers.models.modeling_utils import \
715
+ load_model_dict_into_meta
716
+ from diffusers.utils import is_accelerate_available
717
+ if is_accelerate_available():
718
+ import accelerate
719
+
720
+ # Instantiate model with empty weights
721
+ with accelerate.init_empty_weights():
722
+ model = cls.from_config(config, **transformer_additional_kwargs)
723
+
724
+ param_device = "cpu"
725
+ if os.path.exists(model_file):
726
+ state_dict = torch.load(model_file, map_location="cpu")
727
+ elif os.path.exists(model_file_safetensors):
728
+ from safetensors.torch import load_file, safe_open
729
+ state_dict = load_file(model_file_safetensors)
730
+ else:
731
+ from safetensors.torch import load_file, safe_open
732
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
733
+ state_dict = {}
734
+ for _model_file_safetensors in model_files_safetensors:
735
+ _state_dict = load_file(_model_file_safetensors)
736
+ for key in _state_dict:
737
+ state_dict[key] = _state_dict[key]
738
+ model._convert_deprecated_attention_blocks(state_dict)
739
+
740
+ if diffusers_version >= "0.33.0":
741
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
742
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
743
+ load_model_dict_into_meta(
744
+ model,
745
+ state_dict,
746
+ dtype=torch_dtype,
747
+ model_name_or_path=pretrained_model_path,
748
+ )
749
+ else:
750
+ # move the params from meta device to cpu
751
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
752
+ if len(missing_keys) > 0:
753
+ raise ValueError(
754
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
755
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
756
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
757
+ " those weights or else make sure your checkpoint file is correct."
758
+ )
759
+
760
+ unexpected_keys = load_model_dict_into_meta(
761
+ model,
762
+ state_dict,
763
+ device=param_device,
764
+ dtype=torch_dtype,
765
+ model_name_or_path=pretrained_model_path,
766
+ )
767
+
768
+ if cls._keys_to_ignore_on_load_unexpected is not None:
769
+ for pat in cls._keys_to_ignore_on_load_unexpected:
770
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
771
+
772
+ if len(unexpected_keys) > 0:
773
+ print(
774
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
775
+ )
776
+
777
+ return model
778
+ except Exception as e:
779
+ print(
780
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
781
+ )
782
+
783
+ model = cls.from_config(config, **transformer_additional_kwargs)
784
+ if os.path.exists(model_file):
785
+ state_dict = torch.load(model_file, map_location="cpu")
786
+ elif os.path.exists(model_file_safetensors):
787
+ from safetensors.torch import load_file, safe_open
788
+ state_dict = load_file(model_file_safetensors)
789
+ else:
790
+ from safetensors.torch import load_file, safe_open
791
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
792
+ state_dict = {}
793
+ for _model_file_safetensors in model_files_safetensors:
794
+ _state_dict = load_file(_model_file_safetensors)
795
+ for key in _state_dict:
796
+ state_dict[key] = _state_dict[key]
797
+
798
+ if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
799
+ new_shape = model.state_dict()['patch_embed.proj.weight'].size()
800
+ if len(new_shape) == 5:
801
+ state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
802
+ state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
803
+ elif len(new_shape) == 2:
804
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
805
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
806
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
807
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
808
+ else:
809
+ model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
810
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
811
+ else:
812
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
813
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
814
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
815
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
816
+ else:
817
+ model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
818
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
819
+
820
+ tmp_state_dict = {}
821
+ for key in state_dict:
822
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
823
+ tmp_state_dict[key] = state_dict[key]
824
+ else:
825
+ print(key, "Size don't match, skip")
826
+
827
+ state_dict = tmp_state_dict
828
+
829
+ m, u = model.load_state_dict(state_dict, strict=False)
830
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
831
+ print(m)
832
+
833
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
834
+ print(f"### All Parameters: {sum(params) / 1e6} M")
835
+
836
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
837
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
838
+
839
+ model = model.to(torch_dtype)
840
+ return model
videox_fun/models/cogvideox_vae.py ADDED
@@ -0,0 +1,1675 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import json
23
+ import os
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
27
+ from diffusers.utils import logging
28
+ from diffusers.utils.accelerate_utils import apply_forward_hook
29
+ from diffusers.models.activations import get_activation
30
+ from diffusers.models.downsampling import CogVideoXDownsample3D
31
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.upsampling import CogVideoXUpsample3D
34
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ class CogVideoXSafeConv3d(nn.Conv3d):
41
+ r"""
42
+ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
43
+ """
44
+
45
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
46
+ memory_count = (
47
+ (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
48
+ )
49
+
50
+ # Set to 2GB, suitable for CuDNN
51
+ if memory_count > 2:
52
+ kernel_size = self.kernel_size[0]
53
+ part_num = int(memory_count / 2) + 1
54
+ input_chunks = torch.chunk(input, part_num, dim=2)
55
+
56
+ if kernel_size > 1:
57
+ input_chunks = [input_chunks[0]] + [
58
+ torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
59
+ for i in range(1, len(input_chunks))
60
+ ]
61
+
62
+ output_chunks = []
63
+ for input_chunk in input_chunks:
64
+ output_chunks.append(super().forward(input_chunk))
65
+ output = torch.cat(output_chunks, dim=2)
66
+ return output
67
+ else:
68
+ return super().forward(input)
69
+
70
+
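To make the 2 GB rule above concrete, a quick sketch (example shape assumed, nothing allocated) of how the chunk count is derived before the split along the frame axis:

import math

shape = (1, 128, 49, 480, 720)                 # example decoder activation (B, C, T, H, W)
memory_gib = math.prod(shape) * 2 / 1024**3    # same formula as above: 2 bytes per value, in GiB
part_num = int(memory_gib / 2) + 1             # ~4.04 GiB -> 3 chunks along dim=2 (time)
print(part_num)
# Each chunk after the first is prepended with the trailing (kernel_size - 1) frames of its
# predecessor, so concatenating the per-chunk conv outputs matches the unsplit result.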
71
+ class CogVideoXCausalConv3d(nn.Module):
72
+ r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
73
+
74
+ Args:
75
+ in_channels (`int`): Number of channels in the input tensor.
76
+ out_channels (`int`): Number of output channels produced by the convolution.
77
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
78
+ stride (`int`, defaults to `1`): Stride of the convolution.
79
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
80
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ in_channels: int,
86
+ out_channels: int,
87
+ kernel_size: Union[int, Tuple[int, int, int]],
88
+ stride: int = 1,
89
+ dilation: int = 1,
90
+ pad_mode: str = "constant",
91
+ ):
92
+ super().__init__()
93
+
94
+ if isinstance(kernel_size, int):
95
+ kernel_size = (kernel_size,) * 3
96
+
97
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
98
+
99
+ # TODO(aryan): configure calculation based on stride and dilation in the future.
100
+ # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
101
+ time_pad = time_kernel_size - 1
102
+ height_pad = (height_kernel_size - 1) // 2
103
+ width_pad = (width_kernel_size - 1) // 2
104
+
105
+ self.pad_mode = pad_mode
106
+ self.height_pad = height_pad
107
+ self.width_pad = width_pad
108
+ self.time_pad = time_pad
109
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
110
+
111
+ self.temporal_dim = 2
112
+ self.time_kernel_size = time_kernel_size
113
+
114
+ stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
115
+ dilation = (dilation, 1, 1)
116
+ self.conv = CogVideoXSafeConv3d(
117
+ in_channels=in_channels,
118
+ out_channels=out_channels,
119
+ kernel_size=kernel_size,
120
+ stride=stride,
121
+ dilation=dilation,
122
+ )
123
+
124
+ def fake_context_parallel_forward(
125
+ self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
126
+ ) -> torch.Tensor:
127
+ if self.pad_mode == "replicate":
128
+ inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
129
+ else:
130
+ kernel_size = self.time_kernel_size
131
+ if kernel_size > 1:
132
+ cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
133
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
134
+ return inputs
135
+
136
+ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
137
+ inputs = self.fake_context_parallel_forward(inputs, conv_cache)
138
+
139
+ if self.pad_mode == "replicate":
140
+ conv_cache = None
141
+ else:
142
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
143
+ conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
144
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
145
+
146
+ output = self.conv(inputs)
147
+ return output, conv_cache
148
+
149
+
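A small illustration (toy shapes, kernel_size=3 assumed) of the `time_causal_padding` built above: all temporal padding sits before the clip, so an output frame never depends on later frames. The `replicate` branch applies it directly with `F.pad`:

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 5, 32, 32)              # (B, C, T, H, W)
pad = (1, 1, 1, 1, 2, 0)                      # (W, W, H, H, 2 frames before, 0 frames after)
print(F.pad(x, pad, mode="replicate").shape)  # torch.Size([1, 8, 7, 34, 34])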
150
+ class CogVideoXSpatialNorm3D(nn.Module):
151
+ r"""
152
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
153
+ to 3D video-like data.
154
+
155
+ CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
156
+
157
+ Args:
158
+ f_channels (`int`):
159
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
160
+ zq_channels (`int`):
161
+ The number of channels for the quantized vector as described in the paper.
162
+ groups (`int`):
163
+ Number of groups to separate the channels into for group normalization.
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ f_channels: int,
169
+ zq_channels: int,
170
+ groups: int = 32,
171
+ ):
172
+ super().__init__()
173
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
174
+ self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
175
+ self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
176
+
177
+ def forward(
178
+ self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
179
+ ) -> torch.Tensor:
180
+ new_conv_cache = {}
181
+ conv_cache = conv_cache or {}
182
+
183
+ if f.shape[2] > 1 and f.shape[2] % 2 == 1:
184
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
185
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
186
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
187
+ z_first = F.interpolate(z_first, size=f_first_size)
188
+ z_rest = F.interpolate(z_rest, size=f_rest_size)
189
+ zq = torch.cat([z_first, z_rest], dim=2)
190
+ else:
191
+ zq = F.interpolate(zq, size=f.shape[-3:])
192
+
193
+ conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
194
+ conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
195
+
196
+ norm_f = self.norm_layer(f)
197
+ new_f = norm_f * conv_y + conv_b
198
+ return new_f, new_conv_cache
199
+
200
+
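Functionally, the block above modulates group-normalised features with a scale and shift derived from the latent `zq` after resizing it to the feature resolution. A rough sketch with toy shapes, using plain `Conv3d` layers in place of the causal convolutions:

import torch
import torch.nn.functional as F

f = torch.randn(1, 64, 4, 16, 16)                    # decoder features
zq = torch.randn(1, 16, 4, 8, 8)                     # latent conditioning
zq = F.interpolate(zq, size=f.shape[-3:])            # match the feature resolution
scale = torch.nn.Conv3d(16, 64, kernel_size=1)(zq)   # stands in for conv_y
shift = torch.nn.Conv3d(16, 64, kernel_size=1)(zq)   # stands in for conv_b
out = torch.nn.GroupNorm(32, 64)(f) * scale + shift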
201
+ class CogVideoXUpsample3D(nn.Module):
202
+ r"""
203
+ A 3D upsampling layer used in CogVideoX by Tsinghua University & ZhipuAI. # TODO: wait for paper release.
204
+
205
+ Args:
206
+ in_channels (`int`):
207
+ Number of channels in the input image.
208
+ out_channels (`int`):
209
+ Number of channels produced by the convolution.
210
+ kernel_size (`int`, defaults to `3`):
211
+ Size of the convolving kernel.
212
+ stride (`int`, defaults to `1`):
213
+ Stride of the convolution.
214
+ padding (`int`, defaults to `1`):
215
+ Padding added to all four sides of the input.
216
+ compress_time (`bool`, defaults to `False`):
217
+ Whether or not to compress the time dimension.
218
+ """
219
+
220
+ def __init__(
221
+ self,
222
+ in_channels: int,
223
+ out_channels: int,
224
+ kernel_size: int = 3,
225
+ stride: int = 1,
226
+ padding: int = 1,
227
+ compress_time: bool = False,
228
+ ) -> None:
229
+ super().__init__()
230
+
231
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
232
+ self.compress_time = compress_time
233
+
234
+ self.auto_split_process = True
235
+ self.first_frame_flag = False
236
+
237
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
238
+ if self.compress_time:
239
+ if self.auto_split_process:
240
+ if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
241
+ # split first frame
242
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
243
+
244
+ x_first = F.interpolate(x_first, scale_factor=2.0)
245
+ x_rest = F.interpolate(x_rest, scale_factor=2.0)
246
+ x_first = x_first[:, :, None, :, :]
247
+ inputs = torch.cat([x_first, x_rest], dim=2)
248
+ elif inputs.shape[2] > 1:
249
+ inputs = F.interpolate(inputs, scale_factor=2.0)
250
+ else:
251
+ inputs = inputs.squeeze(2)
252
+ inputs = F.interpolate(inputs, scale_factor=2.0)
253
+ inputs = inputs[:, :, None, :, :]
254
+ else:
255
+ if self.first_frame_flag:
256
+ inputs = inputs.squeeze(2)
257
+ inputs = F.interpolate(inputs, scale_factor=2.0)
258
+ inputs = inputs[:, :, None, :, :]
259
+ else:
260
+ inputs = F.interpolate(inputs, scale_factor=2.0)
261
+ else:
262
+ # only interpolate 2D
263
+ b, c, t, h, w = inputs.shape
264
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
265
+ inputs = F.interpolate(inputs, scale_factor=2.0)
266
+ inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
267
+
268
+ b, c, t, h, w = inputs.shape
269
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
270
+ inputs = self.conv(inputs)
271
+ inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
272
+
273
+ return inputs
274
+
275
+
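For the `compress_time` path above, a quick shape check (toy sizes) of the odd-frame branch: the first frame is upsampled spatially only while the remaining frames are also doubled in time, so 9 frames become 1 + 16 = 17:

import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 9, 60, 90)
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = F.interpolate(x_first, scale_factor=2.0)[:, :, None]  # (1, 16, 1, 120, 180)
x_rest = F.interpolate(x_rest, scale_factor=2.0)                # (1, 16, 16, 120, 180)
print(torch.cat([x_first, x_rest], dim=2).shape)                # torch.Size([1, 16, 17, 120, 180])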
276
+ class CogVideoXResnetBlock3D(nn.Module):
277
+ r"""
278
+ A 3D ResNet block used in the CogVideoX model.
279
+
280
+ Args:
281
+ in_channels (`int`):
282
+ Number of input channels.
283
+ out_channels (`int`, *optional*):
284
+ Number of output channels. If None, defaults to `in_channels`.
285
+ dropout (`float`, defaults to `0.0`):
286
+ Dropout rate.
287
+ temb_channels (`int`, defaults to `512`):
288
+ Number of time embedding channels.
289
+ groups (`int`, defaults to `32`):
290
+ Number of groups to separate the channels into for group normalization.
291
+ eps (`float`, defaults to `1e-6`):
292
+ Epsilon value for normalization layers.
293
+ non_linearity (`str`, defaults to `"swish"`):
294
+ Activation function to use.
295
+ conv_shortcut (bool, defaults to `False`):
296
+ Whether or not to use a convolution shortcut.
297
+ spatial_norm_dim (`int`, *optional*):
298
+ The dimension to use for spatial norm if it is to be used instead of group norm.
299
+ pad_mode (str, defaults to `"first"`):
300
+ Padding mode.
301
+ """
302
+
303
+ def __init__(
304
+ self,
305
+ in_channels: int,
306
+ out_channels: Optional[int] = None,
307
+ dropout: float = 0.0,
308
+ temb_channels: int = 512,
309
+ groups: int = 32,
310
+ eps: float = 1e-6,
311
+ non_linearity: str = "swish",
312
+ conv_shortcut: bool = False,
313
+ spatial_norm_dim: Optional[int] = None,
314
+ pad_mode: str = "first",
315
+ ):
316
+ super().__init__()
317
+
318
+ out_channels = out_channels or in_channels
319
+
320
+ self.in_channels = in_channels
321
+ self.out_channels = out_channels
322
+ self.nonlinearity = get_activation(non_linearity)
323
+ self.use_conv_shortcut = conv_shortcut
324
+ self.spatial_norm_dim = spatial_norm_dim
325
+
326
+ if spatial_norm_dim is None:
327
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
328
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
329
+ else:
330
+ self.norm1 = CogVideoXSpatialNorm3D(
331
+ f_channels=in_channels,
332
+ zq_channels=spatial_norm_dim,
333
+ groups=groups,
334
+ )
335
+ self.norm2 = CogVideoXSpatialNorm3D(
336
+ f_channels=out_channels,
337
+ zq_channels=spatial_norm_dim,
338
+ groups=groups,
339
+ )
340
+
341
+ self.conv1 = CogVideoXCausalConv3d(
342
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
343
+ )
344
+
345
+ if temb_channels > 0:
346
+ self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
347
+
348
+ self.dropout = nn.Dropout(dropout)
349
+ self.conv2 = CogVideoXCausalConv3d(
350
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
351
+ )
352
+
353
+ if self.in_channels != self.out_channels:
354
+ if self.use_conv_shortcut:
355
+ self.conv_shortcut = CogVideoXCausalConv3d(
356
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
357
+ )
358
+ else:
359
+ self.conv_shortcut = CogVideoXSafeConv3d(
360
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
361
+ )
362
+
363
+ def forward(
364
+ self,
365
+ inputs: torch.Tensor,
366
+ temb: Optional[torch.Tensor] = None,
367
+ zq: Optional[torch.Tensor] = None,
368
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
369
+ ) -> torch.Tensor:
370
+ new_conv_cache = {}
371
+ conv_cache = conv_cache or {}
372
+
373
+ hidden_states = inputs
374
+
375
+ if zq is not None:
376
+ hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
377
+ else:
378
+ hidden_states = self.norm1(hidden_states)
379
+
380
+ hidden_states = self.nonlinearity(hidden_states)
381
+ hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
382
+
383
+ if temb is not None:
384
+ hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
385
+
386
+ if zq is not None:
387
+ hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
388
+ else:
389
+ hidden_states = self.norm2(hidden_states)
390
+
391
+ hidden_states = self.nonlinearity(hidden_states)
392
+ hidden_states = self.dropout(hidden_states)
393
+ hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
394
+
395
+ if self.in_channels != self.out_channels:
396
+ if self.use_conv_shortcut:
397
+ inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
398
+ inputs, conv_cache=conv_cache.get("conv_shortcut")
399
+ )
400
+ else:
401
+ inputs = self.conv_shortcut(inputs)
402
+
403
+ hidden_states = hidden_states + inputs
404
+ return hidden_states, new_conv_cache
405
+
406
+
407
+ class CogVideoXDownBlock3D(nn.Module):
408
+ r"""
409
+ A downsampling block used in the CogVideoX model.
410
+
411
+ Args:
412
+ in_channels (`int`):
413
+ Number of input channels.
414
+ out_channels (`int`, *optional*):
415
+ Number of output channels. If None, defaults to `in_channels`.
416
+ temb_channels (`int`, defaults to `512`):
417
+ Number of time embedding channels.
418
+ num_layers (`int`, defaults to `1`):
419
+ Number of resnet layers.
420
+ dropout (`float`, defaults to `0.0`):
421
+ Dropout rate.
422
+ resnet_eps (`float`, defaults to `1e-6`):
423
+ Epsilon value for normalization layers.
424
+ resnet_act_fn (`str`, defaults to `"swish"`):
425
+ Activation function to use.
426
+ resnet_groups (`int`, defaults to `32`):
427
+ Number of groups to separate the channels into for group normalization.
428
+ add_downsample (`bool`, defaults to `True`):
429
+ Whether or not to use a downsampling layer. If not used, the output dimension will be the same as the input dimension.
430
+ compress_time (`bool`, defaults to `False`):
431
+ Whether or not to downsample across the temporal dimension.
432
+ pad_mode (str, defaults to `"first"`):
433
+ Padding mode.
434
+ """
435
+
436
+ _supports_gradient_checkpointing = True
437
+
438
+ def __init__(
439
+ self,
440
+ in_channels: int,
441
+ out_channels: int,
442
+ temb_channels: int,
443
+ dropout: float = 0.0,
444
+ num_layers: int = 1,
445
+ resnet_eps: float = 1e-6,
446
+ resnet_act_fn: str = "swish",
447
+ resnet_groups: int = 32,
448
+ add_downsample: bool = True,
449
+ downsample_padding: int = 0,
450
+ compress_time: bool = False,
451
+ pad_mode: str = "first",
452
+ ):
453
+ super().__init__()
454
+
455
+ resnets = []
456
+ for i in range(num_layers):
457
+ in_channel = in_channels if i == 0 else out_channels
458
+ resnets.append(
459
+ CogVideoXResnetBlock3D(
460
+ in_channels=in_channel,
461
+ out_channels=out_channels,
462
+ dropout=dropout,
463
+ temb_channels=temb_channels,
464
+ groups=resnet_groups,
465
+ eps=resnet_eps,
466
+ non_linearity=resnet_act_fn,
467
+ pad_mode=pad_mode,
468
+ )
469
+ )
470
+
471
+ self.resnets = nn.ModuleList(resnets)
472
+ self.downsamplers = None
473
+
474
+ if add_downsample:
475
+ self.downsamplers = nn.ModuleList(
476
+ [
477
+ CogVideoXDownsample3D(
478
+ out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
479
+ )
480
+ ]
481
+ )
482
+
483
+ self.gradient_checkpointing = False
484
+
485
+ def forward(
486
+ self,
487
+ hidden_states: torch.Tensor,
488
+ temb: Optional[torch.Tensor] = None,
489
+ zq: Optional[torch.Tensor] = None,
490
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
491
+ ) -> torch.Tensor:
492
+ r"""Forward method of the `CogVideoXDownBlock3D` class."""
493
+
494
+ new_conv_cache = {}
495
+ conv_cache = conv_cache or {}
496
+
497
+ for i, resnet in enumerate(self.resnets):
498
+ conv_cache_key = f"resnet_{i}"
499
+
500
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
501
+
502
+ def create_custom_forward(module):
503
+ def create_forward(*inputs):
504
+ return module(*inputs)
505
+
506
+ return create_forward
507
+
508
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
509
+ create_custom_forward(resnet),
510
+ hidden_states,
511
+ temb,
512
+ zq,
513
+ conv_cache.get(conv_cache_key),
514
+ )
515
+ else:
516
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
517
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
518
+ )
519
+
520
+ if self.downsamplers is not None:
521
+ for downsampler in self.downsamplers:
522
+ hidden_states = downsampler(hidden_states)
523
+
524
+ return hidden_states, new_conv_cache
525
+
526
+
527
+ class CogVideoXMidBlock3D(nn.Module):
528
+ r"""
529
+ A middle block used in the CogVideoX model.
530
+
531
+ Args:
532
+ in_channels (`int`):
533
+ Number of input channels.
534
+ temb_channels (`int`, defaults to `512`):
535
+ Number of time embedding channels.
536
+ dropout (`float`, defaults to `0.0`):
537
+ Dropout rate.
538
+ num_layers (`int`, defaults to `1`):
539
+ Number of resnet layers.
540
+ resnet_eps (`float`, defaults to `1e-6`):
541
+ Epsilon value for normalization layers.
542
+ resnet_act_fn (`str`, defaults to `"swish"`):
543
+ Activation function to use.
544
+ resnet_groups (`int`, defaults to `32`):
545
+ Number of groups to separate the channels into for group normalization.
546
+ spatial_norm_dim (`int`, *optional*):
547
+ The dimension to use for spatial norm if it is to be used instead of group norm.
548
+ pad_mode (str, defaults to `"first"`):
549
+ Padding mode.
550
+ """
551
+
552
+ _supports_gradient_checkpointing = True
553
+
554
+ def __init__(
555
+ self,
556
+ in_channels: int,
557
+ temb_channels: int,
558
+ dropout: float = 0.0,
559
+ num_layers: int = 1,
560
+ resnet_eps: float = 1e-6,
561
+ resnet_act_fn: str = "swish",
562
+ resnet_groups: int = 32,
563
+ spatial_norm_dim: Optional[int] = None,
564
+ pad_mode: str = "first",
565
+ ):
566
+ super().__init__()
567
+
568
+ resnets = []
569
+ for _ in range(num_layers):
570
+ resnets.append(
571
+ CogVideoXResnetBlock3D(
572
+ in_channels=in_channels,
573
+ out_channels=in_channels,
574
+ dropout=dropout,
575
+ temb_channels=temb_channels,
576
+ groups=resnet_groups,
577
+ eps=resnet_eps,
578
+ spatial_norm_dim=spatial_norm_dim,
579
+ non_linearity=resnet_act_fn,
580
+ pad_mode=pad_mode,
581
+ )
582
+ )
583
+ self.resnets = nn.ModuleList(resnets)
584
+
585
+ self.gradient_checkpointing = False
586
+
587
+ def forward(
588
+ self,
589
+ hidden_states: torch.Tensor,
590
+ temb: Optional[torch.Tensor] = None,
591
+ zq: Optional[torch.Tensor] = None,
592
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
593
+ ) -> torch.Tensor:
594
+ r"""Forward method of the `CogVideoXMidBlock3D` class."""
595
+
596
+ new_conv_cache = {}
597
+ conv_cache = conv_cache or {}
598
+
599
+ for i, resnet in enumerate(self.resnets):
600
+ conv_cache_key = f"resnet_{i}"
601
+
602
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
603
+
604
+ def create_custom_forward(module):
605
+ def create_forward(*inputs):
606
+ return module(*inputs)
607
+
608
+ return create_forward
609
+
610
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
611
+ create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
612
+ )
613
+ else:
614
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
615
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
616
+ )
617
+
618
+ return hidden_states, new_conv_cache
619
+
620
+
621
+ class CogVideoXUpBlock3D(nn.Module):
622
+ r"""
623
+ An upsampling block used in the CogVideoX model.
624
+
625
+ Args:
626
+ in_channels (`int`):
627
+ Number of input channels.
628
+ out_channels (`int`, *optional*):
629
+ Number of output channels. If None, defaults to `in_channels`.
630
+ temb_channels (`int`, defaults to `512`):
631
+ Number of time embedding channels.
632
+ dropout (`float`, defaults to `0.0`):
633
+ Dropout rate.
634
+ num_layers (`int`, defaults to `1`):
635
+ Number of resnet layers.
636
+ resnet_eps (`float`, defaults to `1e-6`):
637
+ Epsilon value for normalization layers.
638
+ resnet_act_fn (`str`, defaults to `"swish"`):
639
+ Activation function to use.
640
+ resnet_groups (`int`, defaults to `32`):
641
+ Number of groups to separate the channels into for group normalization.
642
+ spatial_norm_dim (`int`, defaults to `16`):
643
+ The dimension to use for spatial norm if it is to be used instead of group norm.
644
+ add_upsample (`bool`, defaults to `True`):
645
+ Whether or not to use an upsampling layer. If not used, the output dimension will be the same as the input dimension.
646
+ compress_time (`bool`, defaults to `False`):
647
+ Whether or not to upsample across the temporal dimension.
648
+ pad_mode (str, defaults to `"first"`):
649
+ Padding mode.
650
+ """
651
+
652
+ def __init__(
653
+ self,
654
+ in_channels: int,
655
+ out_channels: int,
656
+ temb_channels: int,
657
+ dropout: float = 0.0,
658
+ num_layers: int = 1,
659
+ resnet_eps: float = 1e-6,
660
+ resnet_act_fn: str = "swish",
661
+ resnet_groups: int = 32,
662
+ spatial_norm_dim: int = 16,
663
+ add_upsample: bool = True,
664
+ upsample_padding: int = 1,
665
+ compress_time: bool = False,
666
+ pad_mode: str = "first",
667
+ ):
668
+ super().__init__()
669
+
670
+ resnets = []
671
+ for i in range(num_layers):
672
+ in_channel = in_channels if i == 0 else out_channels
673
+ resnets.append(
674
+ CogVideoXResnetBlock3D(
675
+ in_channels=in_channel,
676
+ out_channels=out_channels,
677
+ dropout=dropout,
678
+ temb_channels=temb_channels,
679
+ groups=resnet_groups,
680
+ eps=resnet_eps,
681
+ non_linearity=resnet_act_fn,
682
+ spatial_norm_dim=spatial_norm_dim,
683
+ pad_mode=pad_mode,
684
+ )
685
+ )
686
+
687
+ self.resnets = nn.ModuleList(resnets)
688
+ self.upsamplers = None
689
+
690
+ if add_upsample:
691
+ self.upsamplers = nn.ModuleList(
692
+ [
693
+ CogVideoXUpsample3D(
694
+ out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
695
+ )
696
+ ]
697
+ )
698
+
699
+ self.gradient_checkpointing = False
700
+
701
+ def forward(
702
+ self,
703
+ hidden_states: torch.Tensor,
704
+ temb: Optional[torch.Tensor] = None,
705
+ zq: Optional[torch.Tensor] = None,
706
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
707
+ ) -> torch.Tensor:
708
+ r"""Forward method of the `CogVideoXUpBlock3D` class."""
709
+
710
+ new_conv_cache = {}
711
+ conv_cache = conv_cache or {}
712
+
713
+ for i, resnet in enumerate(self.resnets):
714
+ conv_cache_key = f"resnet_{i}"
715
+
716
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
717
+
718
+ def create_custom_forward(module):
719
+ def create_forward(*inputs):
720
+ return module(*inputs)
721
+
722
+ return create_forward
723
+
724
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
725
+ create_custom_forward(resnet),
726
+ hidden_states,
727
+ temb,
728
+ zq,
729
+ conv_cache.get(conv_cache_key),
730
+ )
731
+ else:
732
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
733
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
734
+ )
735
+
736
+ if self.upsamplers is not None:
737
+ for upsampler in self.upsamplers:
738
+ hidden_states = upsampler(hidden_states)
739
+
740
+ return hidden_states, new_conv_cache
741
+
742
+
743
+ class CogVideoXEncoder3D(nn.Module):
744
+ r"""
745
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
746
+
747
+ Args:
748
+ in_channels (`int`, *optional*, defaults to 3):
749
+ The number of input channels.
750
+ out_channels (`int`, *optional*, defaults to 3):
751
+ The number of output channels.
752
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
753
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
754
+ options.
755
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
756
+ The number of output channels for each block.
757
+ act_fn (`str`, *optional*, defaults to `"silu"`):
758
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
759
+ layers_per_block (`int`, *optional*, defaults to 2):
760
+ The number of layers per block.
761
+ norm_num_groups (`int`, *optional*, defaults to 32):
762
+ The number of groups for normalization.
763
+ """
764
+
765
+ _supports_gradient_checkpointing = True
766
+
767
+ def __init__(
768
+ self,
769
+ in_channels: int = 3,
770
+ out_channels: int = 16,
771
+ down_block_types: Tuple[str, ...] = (
772
+ "CogVideoXDownBlock3D",
773
+ "CogVideoXDownBlock3D",
774
+ "CogVideoXDownBlock3D",
775
+ "CogVideoXDownBlock3D",
776
+ ),
777
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
778
+ layers_per_block: int = 3,
779
+ act_fn: str = "silu",
780
+ norm_eps: float = 1e-6,
781
+ norm_num_groups: int = 32,
782
+ dropout: float = 0.0,
783
+ pad_mode: str = "first",
784
+ temporal_compression_ratio: float = 4,
785
+ ):
786
+ super().__init__()
787
+
788
+ # log2 of temporal_compress_times
789
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
790
+
791
+ self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
792
+ self.down_blocks = nn.ModuleList([])
793
+
794
+ # down blocks
795
+ output_channel = block_out_channels[0]
796
+ for i, down_block_type in enumerate(down_block_types):
797
+ input_channel = output_channel
798
+ output_channel = block_out_channels[i]
799
+ is_final_block = i == len(block_out_channels) - 1
800
+ compress_time = i < temporal_compress_level
801
+
802
+ if down_block_type == "CogVideoXDownBlock3D":
803
+ down_block = CogVideoXDownBlock3D(
804
+ in_channels=input_channel,
805
+ out_channels=output_channel,
806
+ temb_channels=0,
807
+ dropout=dropout,
808
+ num_layers=layers_per_block,
809
+ resnet_eps=norm_eps,
810
+ resnet_act_fn=act_fn,
811
+ resnet_groups=norm_num_groups,
812
+ add_downsample=not is_final_block,
813
+ compress_time=compress_time,
814
+ )
815
+ else:
816
+ raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
817
+
818
+ self.down_blocks.append(down_block)
819
+
820
+ # mid block
821
+ self.mid_block = CogVideoXMidBlock3D(
822
+ in_channels=block_out_channels[-1],
823
+ temb_channels=0,
824
+ dropout=dropout,
825
+ num_layers=2,
826
+ resnet_eps=norm_eps,
827
+ resnet_act_fn=act_fn,
828
+ resnet_groups=norm_num_groups,
829
+ pad_mode=pad_mode,
830
+ )
831
+
832
+ self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
833
+ self.conv_act = nn.SiLU()
834
+ self.conv_out = CogVideoXCausalConv3d(
835
+ block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
836
+ )
837
+
838
+ self.gradient_checkpointing = False
839
+
840
+ def forward(
841
+ self,
842
+ sample: torch.Tensor,
843
+ temb: Optional[torch.Tensor] = None,
844
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
845
+ ) -> torch.Tensor:
846
+ r"""The forward method of the `CogVideoXEncoder3D` class."""
847
+
848
+ new_conv_cache = {}
849
+ conv_cache = conv_cache or {}
850
+
851
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
852
+
853
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
854
+
855
+ def create_custom_forward(module):
856
+ def custom_forward(*inputs):
857
+ return module(*inputs)
858
+
859
+ return custom_forward
860
+
861
+ # 1. Down
862
+ for i, down_block in enumerate(self.down_blocks):
863
+ conv_cache_key = f"down_block_{i}"
864
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
865
+ create_custom_forward(down_block),
866
+ hidden_states,
867
+ temb,
868
+ None,
869
+ conv_cache.get(conv_cache_key),
870
+ )
871
+
872
+ # 2. Mid
873
+ hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
874
+ create_custom_forward(self.mid_block),
875
+ hidden_states,
876
+ temb,
877
+ None,
878
+ conv_cache.get("mid_block"),
879
+ )
880
+ else:
881
+ # 1. Down
882
+ for i, down_block in enumerate(self.down_blocks):
883
+ conv_cache_key = f"down_block_{i}"
884
+ hidden_states, new_conv_cache[conv_cache_key] = down_block(
885
+ hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
886
+ )
887
+
888
+ # 2. Mid
889
+ hidden_states, new_conv_cache["mid_block"] = self.mid_block(
890
+ hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
891
+ )
892
+
893
+ # 3. Post-process
894
+ hidden_states = self.norm_out(hidden_states)
895
+ hidden_states = self.conv_act(hidden_states)
896
+
897
+ hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
898
+
899
+ return hidden_states, new_conv_cache
900
+
901
+
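Under the defaults above (temporal compression 4, three spatial downsampling stages, 16 latent channels), a back-of-the-envelope check of the latent geometry, assuming the usual 4k + 1 frame convention:

frames, height, width = 49, 480, 720
temporal_ratio, spatial_ratio = 4, 2 ** (4 - 1)
latent_shape = ((frames - 1) // temporal_ratio + 1, height // spatial_ratio, width // spatial_ratio)
print(latent_shape)  # (13, 60, 90); conv_out then emits 2 * 16 channels (mean and logvar)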
902
+ class CogVideoXDecoder3D(nn.Module):
903
+ r"""
904
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
905
+ sample.
906
+
907
+ Args:
908
+ in_channels (`int`, *optional*, defaults to 3):
909
+ The number of input channels.
910
+ out_channels (`int`, *optional*, defaults to 3):
911
+ The number of output channels.
912
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
913
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
914
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
915
+ The number of output channels for each block.
916
+ act_fn (`str`, *optional*, defaults to `"silu"`):
917
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
918
+ layers_per_block (`int`, *optional*, defaults to 2):
919
+ The number of layers per block.
920
+ norm_num_groups (`int`, *optional*, defaults to 32):
921
+ The number of groups for normalization.
922
+ """
923
+
924
+ _supports_gradient_checkpointing = True
925
+
926
+ def __init__(
927
+ self,
928
+ in_channels: int = 16,
929
+ out_channels: int = 3,
930
+ up_block_types: Tuple[str, ...] = (
931
+ "CogVideoXUpBlock3D",
932
+ "CogVideoXUpBlock3D",
933
+ "CogVideoXUpBlock3D",
934
+ "CogVideoXUpBlock3D",
935
+ ),
936
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
937
+ layers_per_block: int = 3,
938
+ act_fn: str = "silu",
939
+ norm_eps: float = 1e-6,
940
+ norm_num_groups: int = 32,
941
+ dropout: float = 0.0,
942
+ pad_mode: str = "first",
943
+ temporal_compression_ratio: float = 4,
944
+ ):
945
+ super().__init__()
946
+
947
+ reversed_block_out_channels = list(reversed(block_out_channels))
948
+
949
+ self.conv_in = CogVideoXCausalConv3d(
950
+ in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
951
+ )
952
+
953
+ # mid block
954
+ self.mid_block = CogVideoXMidBlock3D(
955
+ in_channels=reversed_block_out_channels[0],
956
+ temb_channels=0,
957
+ num_layers=2,
958
+ resnet_eps=norm_eps,
959
+ resnet_act_fn=act_fn,
960
+ resnet_groups=norm_num_groups,
961
+ spatial_norm_dim=in_channels,
962
+ pad_mode=pad_mode,
963
+ )
964
+
965
+ # up blocks
966
+ self.up_blocks = nn.ModuleList([])
967
+
968
+ output_channel = reversed_block_out_channels[0]
969
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
970
+
971
+ for i, up_block_type in enumerate(up_block_types):
972
+ prev_output_channel = output_channel
973
+ output_channel = reversed_block_out_channels[i]
974
+ is_final_block = i == len(block_out_channels) - 1
975
+ compress_time = i < temporal_compress_level
976
+
977
+ if up_block_type == "CogVideoXUpBlock3D":
978
+ up_block = CogVideoXUpBlock3D(
979
+ in_channels=prev_output_channel,
980
+ out_channels=output_channel,
981
+ temb_channels=0,
982
+ dropout=dropout,
983
+ num_layers=layers_per_block + 1,
984
+ resnet_eps=norm_eps,
985
+ resnet_act_fn=act_fn,
986
+ resnet_groups=norm_num_groups,
987
+ spatial_norm_dim=in_channels,
988
+ add_upsample=not is_final_block,
989
+ compress_time=compress_time,
990
+ pad_mode=pad_mode,
991
+ )
992
+ prev_output_channel = output_channel
993
+ else:
994
+ raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
995
+
996
+ self.up_blocks.append(up_block)
997
+
998
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
999
+ self.conv_act = nn.SiLU()
1000
+ self.conv_out = CogVideoXCausalConv3d(
1001
+ reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
1002
+ )
1003
+
1004
+ self.gradient_checkpointing = False
1005
+
1006
+ def forward(
1007
+ self,
1008
+ sample: torch.Tensor,
1009
+ temb: Optional[torch.Tensor] = None,
1010
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
1011
+ ) -> torch.Tensor:
1012
+ r"""The forward method of the `CogVideoXDecoder3D` class."""
1013
+
1014
+ new_conv_cache = {}
1015
+ conv_cache = conv_cache or {}
1016
+
1017
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
1018
+
1019
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1020
+
1021
+ def create_custom_forward(module):
1022
+ def custom_forward(*inputs):
1023
+ return module(*inputs)
1024
+
1025
+ return custom_forward
1026
+
1027
+ # 1. Mid
1028
+ hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
1029
+ create_custom_forward(self.mid_block),
1030
+ hidden_states,
1031
+ temb,
1032
+ sample,
1033
+ conv_cache.get("mid_block"),
1034
+ )
1035
+
1036
+ # 2. Up
1037
+ for i, up_block in enumerate(self.up_blocks):
1038
+ conv_cache_key = f"up_block_{i}"
1039
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
1040
+ create_custom_forward(up_block),
1041
+ hidden_states,
1042
+ temb,
1043
+ sample,
1044
+ conv_cache.get(conv_cache_key),
1045
+ )
1046
+ else:
1047
+ # 1. Mid
1048
+ hidden_states, new_conv_cache["mid_block"] = self.mid_block(
1049
+ hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
1050
+ )
1051
+
1052
+ # 2. Up
1053
+ for i, up_block in enumerate(self.up_blocks):
1054
+ conv_cache_key = f"up_block_{i}"
1055
+ hidden_states, new_conv_cache[conv_cache_key] = up_block(
1056
+ hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
1057
+ )
1058
+
1059
+ # 3. Post-process
1060
+ hidden_states, new_conv_cache["norm_out"] = self.norm_out(
1061
+ hidden_states, sample, conv_cache=conv_cache.get("norm_out")
1062
+ )
1063
+ hidden_states = self.conv_act(hidden_states)
1064
+ hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
1065
+
1066
+ return hidden_states, new_conv_cache
1067
+
1068
+
1069
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1070
+ r"""
1071
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
1072
+ [CogVideoX](https://github.com/THUDM/CogVideo).
1073
+
1074
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
1075
+ for all models (such as downloading or saving).
1076
+
1077
+ Parameters:
1078
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
1079
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
1080
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
1081
+ Tuple of downsample block types.
1082
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
1083
+ Tuple of upsample block types.
1084
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
1085
+ Tuple of block output channels.
1086
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
1087
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
1088
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
1089
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
1090
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
1091
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
1092
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
1093
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
1094
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
1095
+ force_upcast (`bool`, *optional*, default to `True`):
1096
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
1097
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
1098
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
1099
+ """
1100
+
1101
+ _supports_gradient_checkpointing = True
1102
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
1103
+
1104
+ @register_to_config
1105
+ def __init__(
1106
+ self,
1107
+ in_channels: int = 3,
1108
+ out_channels: int = 3,
1109
+ down_block_types: Tuple[str] = (
1110
+ "CogVideoXDownBlock3D",
1111
+ "CogVideoXDownBlock3D",
1112
+ "CogVideoXDownBlock3D",
1113
+ "CogVideoXDownBlock3D",
1114
+ ),
1115
+ up_block_types: Tuple[str] = (
1116
+ "CogVideoXUpBlock3D",
1117
+ "CogVideoXUpBlock3D",
1118
+ "CogVideoXUpBlock3D",
1119
+ "CogVideoXUpBlock3D",
1120
+ ),
1121
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
1122
+ latent_channels: int = 16,
1123
+ layers_per_block: int = 3,
1124
+ act_fn: str = "silu",
1125
+ norm_eps: float = 1e-6,
1126
+ norm_num_groups: int = 32,
1127
+ temporal_compression_ratio: float = 4,
1128
+ sample_height: int = 480,
1129
+ sample_width: int = 720,
1130
+ scaling_factor: float = 1.15258426,
1131
+ shift_factor: Optional[float] = None,
1132
+ latents_mean: Optional[Tuple[float]] = None,
1133
+ latents_std: Optional[Tuple[float]] = None,
1134
+ force_upcast: bool = True,
1135
+ use_quant_conv: bool = False,
1136
+ use_post_quant_conv: bool = False,
1137
+ invert_scale_latents: bool = False,
1138
+ ):
1139
+ super().__init__()
1140
+
1141
+ self.encoder = CogVideoXEncoder3D(
1142
+ in_channels=in_channels,
1143
+ out_channels=latent_channels,
1144
+ down_block_types=down_block_types,
1145
+ block_out_channels=block_out_channels,
1146
+ layers_per_block=layers_per_block,
1147
+ act_fn=act_fn,
1148
+ norm_eps=norm_eps,
1149
+ norm_num_groups=norm_num_groups,
1150
+ temporal_compression_ratio=temporal_compression_ratio,
1151
+ )
1152
+ self.decoder = CogVideoXDecoder3D(
1153
+ in_channels=latent_channels,
1154
+ out_channels=out_channels,
1155
+ up_block_types=up_block_types,
1156
+ block_out_channels=block_out_channels,
1157
+ layers_per_block=layers_per_block,
1158
+ act_fn=act_fn,
1159
+ norm_eps=norm_eps,
1160
+ norm_num_groups=norm_num_groups,
1161
+ temporal_compression_ratio=temporal_compression_ratio,
1162
+ )
1163
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
1164
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
1165
+
1166
+ self.use_slicing = False
1167
+ self.use_tiling = False
1168
+ self.auto_split_process = False
1169
+
1170
+ # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
1171
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
1172
+ # If you decode X latent frames together, the number of output frames is:
1173
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
1174
+ #
1175
+ # Example with num_latent_frames_batch_size = 2:
1176
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
1177
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1178
+ # => 6 * 8 = 48 frames
1179
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
1180
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
1181
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1182
+ # => 1 * 9 + 5 * 8 = 49 frames
1183
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
1184
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
1185
+ # number of temporal frames.
1186
+ self.num_latent_frames_batch_size = 2
1187
+ self.num_sample_frames_batch_size = 8
1188
+
1189
+ # We make the minimum height and width of sample for tiling half that of the generally supported
1190
+ self.tile_sample_min_height = sample_height // 2
1191
+ self.tile_sample_min_width = sample_width // 2
1192
+ self.tile_latent_min_height = int(
1193
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1194
+ )
1195
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1196
+
1197
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
1198
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
1199
+ # and so the tiling implementation has only been tested on those specific resolutions.
1200
+ self.tile_overlap_factor_height = 1 / 6
1201
+ self.tile_overlap_factor_width = 1 / 5
1202
+
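A hypothetical helper (illustration only) that reproduces the frame-count bookkeeping from the comment above, folding the odd leading latent frame into the first decoded slice:

def decoded_frame_count(num_latent_frames: int, batch: int = 2) -> int:
    # Each decoded slice yields (slice_size + 6) output frames, per the breakdown above.
    if num_latent_frames % batch == 1:
        first = (batch + 1) + 6                                        # special first slice of 3 latent frames
        rest = ((num_latent_frames - batch - 1) // batch) * (batch + 6)
        return first + rest
    return (num_latent_frames // batch) * (batch + 6)

print(decoded_frame_count(12), decoded_frame_count(13))  # 48 49, matching the examples above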
1203
+ def _set_gradient_checkpointing(self, module, value=False):
1204
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1205
+ module.gradient_checkpointing = value
1206
+
1207
+ def enable_tiling(
1208
+ self,
1209
+ tile_sample_min_height: Optional[int] = None,
1210
+ tile_sample_min_width: Optional[int] = None,
1211
+ tile_overlap_factor_height: Optional[float] = None,
1212
+ tile_overlap_factor_width: Optional[float] = None,
1213
+ ) -> None:
1214
+ r"""
1215
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1216
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1217
+ processing larger images.
1218
+
1219
+ Args:
1220
+ tile_sample_min_height (`int`, *optional*):
1221
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1222
+ tile_sample_min_width (`int`, *optional*):
1223
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1224
+ tile_overlap_factor_height (`float`, *optional*):
1225
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1226
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
1227
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1228
+ tile_overlap_factor_width (`float`, *optional*):
1229
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
1230
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
1231
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1232
+ """
1233
+ self.use_tiling = True
1234
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1235
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1236
+ self.tile_latent_min_height = int(
1237
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1238
+ )
1239
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1240
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1241
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1242
+
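A hedged usage sketch of the memory-saving switches in this class; the checkpoint path is assumed, not taken from this commit:

# vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae")  # assumed checkpoint
# vae.enable_slicing()                                       # decode one batch element at a time
# vae.enable_tiling(tile_sample_min_height=240,              # spatial tiling with blended overlaps
#                   tile_sample_min_width=360)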
1243
+ def disable_tiling(self) -> None:
1244
+ r"""
1245
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1246
+ decoding in one step.
1247
+ """
1248
+ self.use_tiling = False
1249
+
1250
+ def enable_slicing(self) -> None:
1251
+ r"""
1252
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1253
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1254
+ """
1255
+ self.use_slicing = True
1256
+
1257
+ def disable_slicing(self) -> None:
1258
+ r"""
1259
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1260
+ decoding in one step.
1261
+ """
1262
+ self.use_slicing = False
1263
+
1264
+ def _set_first_frame(self):
1265
+ for name, module in self.named_modules():
1266
+ if isinstance(module, CogVideoXUpsample3D):
1267
+ module.auto_split_process = False
1268
+ module.first_frame_flag = True
1269
+
1270
+ def _set_rest_frame(self):
1271
+ for name, module in self.named_modules():
1272
+ if isinstance(module, CogVideoXUpsample3D):
1273
+ module.auto_split_process = False
1274
+ module.first_frame_flag = False
1275
+
1276
+ def enable_auto_split_process(self) -> None:
1277
+ self.auto_split_process = True
1278
+ for name, module in self.named_modules():
1279
+ if isinstance(module, CogVideoXUpsample3D):
1280
+ module.auto_split_process = True
1281
+
1282
+ def disable_auto_split_process(self) -> None:
1283
+ self.auto_split_process = False
1284
+
1285
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
1286
+ batch_size, num_channels, num_frames, height, width = x.shape
1287
+
1288
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1289
+ return self.tiled_encode(x)
1290
+
1291
+ frame_batch_size = self.num_sample_frames_batch_size
1292
+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1293
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
1294
+ num_batches = max(num_frames // frame_batch_size, 1)
1295
+ conv_cache = None
1296
+ enc = []
1297
+
1298
+ for i in range(num_batches):
1299
+ remaining_frames = num_frames % frame_batch_size
1300
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1301
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1302
+ x_intermediate = x[:, :, start_frame:end_frame]
1303
+ x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
1304
+ if self.quant_conv is not None:
1305
+ x_intermediate = self.quant_conv(x_intermediate)
1306
+ enc.append(x_intermediate)
1307
+
1308
+ enc = torch.cat(enc, dim=2)
1309
+ return enc
1310
+
1311
+ @apply_forward_hook
1312
+ def encode(
1313
+ self, x: torch.Tensor, return_dict: bool = True
1314
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1315
+ """
1316
+ Encode a batch of images into latents.
1317
+
1318
+ Args:
1319
+ x (`torch.Tensor`): Input batch of images.
1320
+ return_dict (`bool`, *optional*, defaults to `True`):
1321
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1322
+
1323
+ Returns:
1324
+ The latent representations of the encoded videos. If `return_dict` is True, a
1325
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1326
+ """
1327
+ if self.use_slicing and x.shape[0] > 1:
1328
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1329
+ h = torch.cat(encoded_slices)
1330
+ else:
1331
+ h = self._encode(x)
1332
+
1333
+ posterior = DiagonalGaussianDistribution(h)
1334
+
1335
+ if not return_dict:
1336
+ return (posterior,)
1337
+ return AutoencoderKLOutput(latent_dist=posterior)
1338
+
1339
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1340
+ batch_size, num_channels, num_frames, height, width = z.shape
1341
+
1342
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
1343
+ return self.tiled_decode(z, return_dict=return_dict)
1344
+
1345
+ if self.auto_split_process:
1346
+ frame_batch_size = self.num_latent_frames_batch_size
1347
+ num_batches = max(num_frames // frame_batch_size, 1)
1348
+ conv_cache = None
1349
+ dec = []
1350
+
1351
+ for i in range(num_batches):
1352
+ remaining_frames = num_frames % frame_batch_size
1353
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1354
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1355
+ z_intermediate = z[:, :, start_frame:end_frame]
1356
+ if self.post_quant_conv is not None:
1357
+ z_intermediate = self.post_quant_conv(z_intermediate)
1358
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1359
+ dec.append(z_intermediate)
1360
+ else:
1361
+ conv_cache = None
1362
+ start_frame = 0
1363
+ end_frame = 1
1364
+ dec = []
1365
+
1366
+ self._set_first_frame()
1367
+ z_intermediate = z[:, :, start_frame:end_frame]
1368
+ if self.post_quant_conv is not None:
1369
+ z_intermediate = self.post_quant_conv(z_intermediate)
1370
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1371
+ dec.append(z_intermediate)
1372
+
1373
+ self._set_rest_frame()
1374
+ start_frame = end_frame
1375
+ end_frame += self.num_latent_frames_batch_size
1376
+
1377
+ while start_frame < num_frames:
1378
+ z_intermediate = z[:, :, start_frame:end_frame]
1379
+ if self.post_quant_conv is not None:
1380
+ z_intermediate = self.post_quant_conv(z_intermediate)
1381
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1382
+ dec.append(z_intermediate)
1383
+ start_frame = end_frame
1384
+ end_frame += self.num_latent_frames_batch_size
1385
+
1386
+ dec = torch.cat(dec, dim=2)
1387
+
1388
+ if not return_dict:
1389
+ return (dec,)
1390
+
1391
+ return DecoderOutput(sample=dec)
1392
+
1393
+ @apply_forward_hook
1394
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1395
+ """
1396
+ Decode a batch of images.
1397
+
1398
+ Args:
1399
+ z (`torch.Tensor`): Input batch of latent vectors.
1400
+ return_dict (`bool`, *optional*, defaults to `True`):
1401
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1402
+
1403
+ Returns:
1404
+ [`~models.vae.DecoderOutput`] or `tuple`:
1405
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1406
+ returned.
1407
+ """
1408
+ if self.use_slicing and z.shape[0] > 1:
1409
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1410
+ decoded = torch.cat(decoded_slices)
1411
+ else:
1412
+ decoded = self._decode(z).sample
1413
+
1414
+ if not return_dict:
1415
+ return (decoded,)
1416
+ return DecoderOutput(sample=decoded)
1417
+
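For completeness, a small end-to-end sketch with a randomly initialised, deliberately tiny configuration (no pretrained weights assumed); the clip length follows the 4k + 1 convention used throughout:

import torch

vae = AutoencoderKLCogVideoX(block_out_channels=(32, 64, 64, 128), layers_per_block=1)
video = torch.randn(1, 3, 9, 64, 64)               # (B, C, T, H, W)
latents = vae.encode(video).latent_dist.sample()   # -> (1, 16, 3, 8, 8)
recon = vae.decode(latents).sample                 # -> (1, 3, 9, 64, 64)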
1418
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1419
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1420
+ for y in range(blend_extent):
1421
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1422
+ y / blend_extent
1423
+ )
1424
+ return b
1425
+
1426
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1427
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1428
+ for x in range(blend_extent):
1429
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1430
+ x / blend_extent
1431
+ )
1432
+ return b
1433
+
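A tiny numeric check of the horizontal blend above (toy tensors): across the overlap the weight ramps linearly from the previous tile to the current one.

import torch

a = torch.ones(1, 1, 1, 1, 4)     # previous tile
b = torch.zeros(1, 1, 1, 1, 4)    # current tile
blend_extent = 2
for x in range(blend_extent):
    b[..., x] = a[..., -blend_extent + x] * (1 - x / blend_extent) + b[..., x] * (x / blend_extent)
print(b.flatten().tolist())        # [1.0, 0.5, 0.0, 0.0]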
1434
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1435
+ r"""Encode a batch of images using a tiled encoder.
1436
+
1437
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
1438
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
1439
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
1440
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
1441
+ output, but they should be much less noticeable.
1442
+
1443
+ Args:
1444
+ x (`torch.Tensor`): Input batch of videos.
1445
+
1446
+ Returns:
1447
+ `torch.Tensor`:
1448
+ The latent representation of the encoded videos.
1449
+ """
1450
+ # For a rough memory estimate, take a look at the `tiled_decode` method.
1451
+ batch_size, num_channels, num_frames, height, width = x.shape
1452
+
1453
+ overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
1454
+ overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
1455
+ blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
1456
+ blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
1457
+ row_limit_height = self.tile_latent_min_height - blend_extent_height
1458
+ row_limit_width = self.tile_latent_min_width - blend_extent_width
1459
+ frame_batch_size = self.num_sample_frames_batch_size
1460
+
1461
+ # Split x into overlapping tiles and encode them separately.
1462
+ # The tiles have an overlap to avoid seams between tiles.
1463
+ rows = []
1464
+ for i in range(0, height, overlap_height):
1465
+ row = []
1466
+ for j in range(0, width, overlap_width):
1467
+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1468
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
1469
+ num_batches = max(num_frames // frame_batch_size, 1)
1470
+ conv_cache = None
1471
+ time = []
1472
+
1473
+ for k in range(num_batches):
1474
+ remaining_frames = num_frames % frame_batch_size
1475
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1476
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1477
+ tile = x[
1478
+ :,
1479
+ :,
1480
+ start_frame:end_frame,
1481
+ i : i + self.tile_sample_min_height,
1482
+ j : j + self.tile_sample_min_width,
1483
+ ]
1484
+ tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
1485
+ if self.quant_conv is not None:
1486
+ tile = self.quant_conv(tile)
1487
+ time.append(tile)
1488
+
1489
+ row.append(torch.cat(time, dim=2))
1490
+ rows.append(row)
1491
+
1492
+ result_rows = []
1493
+ for i, row in enumerate(rows):
1494
+ result_row = []
1495
+ for j, tile in enumerate(row):
1496
+ # blend the above tile and the left tile
1497
+ # to the current tile and add the current tile to the result row
1498
+ if i > 0:
1499
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1500
+ if j > 0:
1501
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1502
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1503
+ result_rows.append(torch.cat(result_row, dim=4))
1504
+
1505
+ enc = torch.cat(result_rows, dim=3)
1506
+ return enc
1507
+
1508
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1509
+ r"""
1510
+ Decode a batch of images using a tiled decoder.
1511
+
1512
+ Args:
1513
+ z (`torch.Tensor`): Input batch of latent vectors.
1514
+ return_dict (`bool`, *optional*, defaults to `True`):
1515
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1516
+
1517
+ Returns:
1518
+ [`~models.vae.DecoderOutput`] or `tuple`:
1519
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1520
+ returned.
1521
+ """
1522
+ # Rough memory assessment:
1523
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
1524
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
1525
+ # - Assume fp16 (2 bytes per value).
1526
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
1527
+ #
1528
+ # Memory assessment when using tiling:
1529
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
1530
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
1531
+
1532
+ batch_size, num_channels, num_frames, height, width = z.shape
1533
+
1534
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
1535
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
1536
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
1537
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
1538
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
1539
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
1540
+ frame_batch_size = self.num_latent_frames_batch_size
1541
+
1542
+ # Split z into overlapping tiles and decode them separately.
1543
+ # The tiles have an overlap to avoid seams between tiles.
1544
+ rows = []
1545
+ for i in range(0, height, overlap_height):
1546
+ row = []
1547
+ for j in range(0, width, overlap_width):
1548
+ if self.auto_split_process:
1549
+ num_batches = max(num_frames // frame_batch_size, 1)
1550
+ conv_cache = None
1551
+ time = []
1552
+
1553
+ for k in range(num_batches):
1554
+ remaining_frames = num_frames % frame_batch_size
1555
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1556
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1557
+ tile = z[
1558
+ :,
1559
+ :,
1560
+ start_frame:end_frame,
1561
+ i : i + self.tile_latent_min_height,
1562
+ j : j + self.tile_latent_min_width,
1563
+ ]
1564
+ if self.post_quant_conv is not None:
1565
+ tile = self.post_quant_conv(tile)
1566
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1567
+ time.append(tile)
1568
+
1569
+ row.append(torch.cat(time, dim=2))
1570
+ else:
1571
+ conv_cache = None
1572
+ start_frame = 0
1573
+ end_frame = 1
1574
+ dec = []
1575
+
1576
+ tile = z[
1577
+ :,
1578
+ :,
1579
+ start_frame:end_frame,
1580
+ i : i + self.tile_latent_min_height,
1581
+ j : j + self.tile_latent_min_width,
1582
+ ]
1583
+
1584
+ self._set_first_frame()
1585
+ if self.post_quant_conv is not None:
1586
+ tile = self.post_quant_conv(tile)
1587
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1588
+ dec.append(tile)
1589
+
1590
+ self._set_rest_frame()
1591
+ start_frame = end_frame
1592
+ end_frame += self.num_latent_frames_batch_size
1593
+
1594
+ while start_frame < num_frames:
1595
+ tile = z[
1596
+ :,
1597
+ :,
1598
+ start_frame:end_frame,
1599
+ i : i + self.tile_latent_min_height,
1600
+ j : j + self.tile_latent_min_width,
1601
+ ]
1602
+ if self.post_quant_conv is not None:
1603
+ tile = self.post_quant_conv(tile)
1604
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1605
+ dec.append(tile)
1606
+ start_frame = end_frame
1607
+ end_frame += self.num_latent_frames_batch_size
1608
+
1609
+ row.append(torch.cat(dec, dim=2))
1610
+ rows.append(row)
1611
+
1612
+ result_rows = []
1613
+ for i, row in enumerate(rows):
1614
+ result_row = []
1615
+ for j, tile in enumerate(row):
1616
+ # blend the above tile and the left tile
1617
+ # to the current tile and add the current tile to the result row
1618
+ if i > 0:
1619
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1620
+ if j > 0:
1621
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1622
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1623
+ result_rows.append(torch.cat(result_row, dim=4))
1624
+
1625
+ dec = torch.cat(result_rows, dim=3)
1626
+
1627
+ if not return_dict:
1628
+ return (dec,)
1629
+
1630
+ return DecoderOutput(sample=dec)
1631
+
1632
+ def forward(
1633
+ self,
1634
+ sample: torch.Tensor,
1635
+ sample_posterior: bool = False,
1636
+ return_dict: bool = True,
1637
+ generator: Optional[torch.Generator] = None,
1638
+ ) -> Union[DecoderOutput, torch.Tensor]:
1639
+ x = sample
1640
+ posterior = self.encode(x).latent_dist
1641
+ if sample_posterior:
1642
+ z = posterior.sample(generator=generator)
1643
+ else:
1644
+ z = posterior.mode()
1645
+ dec = self.decode(z)
1646
+ if not return_dict:
1647
+ return (dec,)
1648
+ return dec
1649
+
1650
+ @classmethod
1651
+ def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
1652
+ if subfolder is not None:
1653
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1654
+
1655
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1656
+ if not os.path.isfile(config_file):
1657
+ raise RuntimeError(f"{config_file} does not exist")
1658
+ with open(config_file, "r") as f:
1659
+ config = json.load(f)
1660
+
1661
+ model = cls.from_config(config, **vae_additional_kwargs)
1662
+ from diffusers.utils import WEIGHTS_NAME
1663
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1664
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1665
+ if os.path.exists(model_file_safetensors):
1666
+ from safetensors.torch import load_file, safe_open
1667
+ state_dict = load_file(model_file_safetensors)
1668
+ else:
1669
+ if not os.path.isfile(model_file):
1670
+ raise RuntimeError(f"{model_file} does not exist")
1671
+ state_dict = torch.load(model_file, map_location="cpu")
1672
+ m, u = model.load_state_dict(state_dict, strict=False)
1673
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1674
+ print(m, u)
1675
+ return model
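
A minimal standalone sketch (not part of the diff above) of the horizontal cross-fade that `blend_h` applies between neighbouring tiles in the tiled encode/decode paths; the tensor sizes and the blend width of 4 columns are arbitrary toy values:

import torch

a = torch.zeros(1, 3, 2, 8, 8)   # left tile,  [B, C, T, H, W]
b = torch.ones(1, 3, 2, 8, 8)    # right tile, [B, C, T, H, W]
blend_extent = 4                 # number of overlapping columns to cross-fade

for x in range(blend_extent):
    # the left tile fades out while the right tile fades in, as in blend_h()
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)

# b[..., :4] now ramps linearly (0.00, 0.25, 0.50, 0.75) across the overlap
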
videox_fun/models/flux_transformer2d.py ADDED
@@ -0,0 +1,940 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py
2
+ # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
26
+ from diffusers.models.attention import FeedForward
27
+ from diffusers.models.attention_processor import AttentionProcessor
28
+ from diffusers.models.embeddings import (
29
+ CombinedTimestepGuidanceTextProjEmbeddings,
30
+ CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed)
31
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import (AdaLayerNormContinuous,
34
+ AdaLayerNormZero,
35
+ AdaLayerNormZeroSingle)
36
+ from diffusers.utils import (USE_PEFT_BACKEND, logging, scale_lora_layers,
37
+ unscale_lora_layers)
38
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
39
+
40
+ from ..dist import (FluxMultiGPUsAttnProcessor2_0, get_sequence_parallel_rank,
41
+ get_sequence_parallel_world_size, get_sp_group)
42
+ from .attention_utils import attention
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
47
+ query = attn.to_q(hidden_states)
48
+ key = attn.to_k(hidden_states)
49
+ value = attn.to_v(hidden_states)
50
+
51
+ encoder_query = encoder_key = encoder_value = None
52
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
53
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
54
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
55
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
56
+
57
+ return query, key, value, encoder_query, encoder_key, encoder_value
58
+
59
+ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
60
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
61
+
62
+ def apply_rotary_emb(
63
+ x: torch.Tensor,
64
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
65
+ use_real: bool = True,
66
+ use_real_unbind_dim: int = -1,
67
+ sequence_dim: int = 2,
68
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
69
+ """
70
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
71
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
72
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
73
+ tensors contain rotary embeddings and are returned as real tensors.
74
+
75
+ Args:
76
+ x (`torch.Tensor`):
77
+ Query or key tensor to apply rotary embeddings to, of shape [B, H, S, D].
78
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
79
+
80
+ Returns:
81
+ torch.Tensor: The query or key tensor with rotary embeddings applied.
82
+ """
83
+ if use_real:
84
+ cos, sin = freqs_cis # [S, D]
85
+ if sequence_dim == 2:
86
+ cos = cos[None, None, :, :]
87
+ sin = sin[None, None, :, :]
88
+ elif sequence_dim == 1:
89
+ cos = cos[None, :, None, :]
90
+ sin = sin[None, :, None, :]
91
+ else:
92
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
93
+
94
+ cos, sin = cos.to(x.device), sin.to(x.device)
95
+
96
+ if use_real_unbind_dim == -1:
97
+ # Used for flux, cogvideox, hunyuan-dit
98
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
99
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
100
+ elif use_real_unbind_dim == -2:
101
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
102
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
103
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
104
+ else:
105
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
106
+
107
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
108
+
109
+ return out
110
+ else:
111
+ # used for lumina
112
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
113
+ freqs_cis = freqs_cis.unsqueeze(2)
114
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
115
+
116
+ return x_out.type_as(x)
117
+
118
+
119
+ class FluxAttnProcessor:
120
+ _attention_backend = None
121
+
122
+ def __init__(self):
123
+ if not hasattr(F, "scaled_dot_product_attention"):
124
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
125
+
126
+ def __call__(
127
+ self,
128
+ attn: "FluxAttention",
129
+ hidden_states: torch.Tensor,
130
+ encoder_hidden_states: torch.Tensor = None,
131
+ attention_mask: Optional[torch.Tensor] = None,
132
+ image_rotary_emb: Optional[torch.Tensor] = None,
133
+ text_seq_len: int = None,
134
+ ) -> torch.Tensor:
135
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
136
+ attn, hidden_states, encoder_hidden_states
137
+ )
138
+
139
+ query = query.unflatten(-1, (attn.heads, -1))
140
+ key = key.unflatten(-1, (attn.heads, -1))
141
+ value = value.unflatten(-1, (attn.heads, -1))
142
+
143
+ query = attn.norm_q(query)
144
+ key = attn.norm_k(key)
145
+
146
+ if attn.added_kv_proj_dim is not None:
147
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
148
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
149
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
150
+
151
+ encoder_query = attn.norm_added_q(encoder_query)
152
+ encoder_key = attn.norm_added_k(encoder_key)
153
+
154
+ query = torch.cat([encoder_query, query], dim=1)
155
+ key = torch.cat([encoder_key, key], dim=1)
156
+ value = torch.cat([encoder_value, value], dim=1)
157
+
158
+ if image_rotary_emb is not None:
159
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
160
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
161
+
162
+ hidden_states = attention(
163
+ query, key, value, attn_mask=attention_mask,
164
+ )
165
+ hidden_states = hidden_states.flatten(2, 3)
166
+ hidden_states = hidden_states.to(query.dtype)
167
+
168
+ if encoder_hidden_states is not None:
169
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
170
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
171
+ )
172
+ hidden_states = attn.to_out[0](hidden_states)
173
+ hidden_states = attn.to_out[1](hidden_states)
174
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
175
+
176
+ return hidden_states, encoder_hidden_states
177
+ else:
178
+ return hidden_states
179
+
180
+
181
+ class FluxIPAdapterAttnProcessor(torch.nn.Module):
182
+ """Flux Attention processor for IP-Adapter."""
183
+
184
+ _attention_backend = None
185
+
186
+ def __init__(
187
+ self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
188
+ ):
189
+ super().__init__()
190
+
191
+ if not hasattr(F, "scaled_dot_product_attention"):
192
+ raise ImportError(
193
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
194
+ )
195
+
196
+ self.hidden_size = hidden_size
197
+ self.cross_attention_dim = cross_attention_dim
198
+
199
+ if not isinstance(num_tokens, (tuple, list)):
200
+ num_tokens = [num_tokens]
201
+
202
+ if not isinstance(scale, list):
203
+ scale = [scale] * len(num_tokens)
204
+ if len(scale) != len(num_tokens):
205
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
206
+ self.scale = scale
207
+
208
+ self.to_k_ip = nn.ModuleList(
209
+ [
210
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
211
+ for _ in range(len(num_tokens))
212
+ ]
213
+ )
214
+ self.to_v_ip = nn.ModuleList(
215
+ [
216
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
217
+ for _ in range(len(num_tokens))
218
+ ]
219
+ )
220
+
221
+ def __call__(
222
+ self,
223
+ attn: "FluxAttention",
224
+ hidden_states: torch.Tensor,
225
+ encoder_hidden_states: torch.Tensor = None,
226
+ attention_mask: Optional[torch.Tensor] = None,
227
+ image_rotary_emb: Optional[torch.Tensor] = None,
228
+ ip_hidden_states: Optional[List[torch.Tensor]] = None,
229
+ ip_adapter_masks: Optional[torch.Tensor] = None,
230
+ ) -> torch.Tensor:
231
+ batch_size = hidden_states.shape[0]
232
+
233
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
234
+ attn, hidden_states, encoder_hidden_states
235
+ )
236
+
237
+ query = query.unflatten(-1, (attn.heads, -1))
238
+ key = key.unflatten(-1, (attn.heads, -1))
239
+ value = value.unflatten(-1, (attn.heads, -1))
240
+
241
+ query = attn.norm_q(query)
242
+ key = attn.norm_k(key)
243
+ ip_query = query
244
+
245
+ if encoder_hidden_states is not None:
246
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
247
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
248
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
249
+
250
+ encoder_query = attn.norm_added_q(encoder_query)
251
+ encoder_key = attn.norm_added_k(encoder_key)
252
+
253
+ query = torch.cat([encoder_query, query], dim=1)
254
+ key = torch.cat([encoder_key, key], dim=1)
255
+ value = torch.cat([encoder_value, value], dim=1)
256
+
257
+ if image_rotary_emb is not None:
258
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
259
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
260
+
261
+ hidden_states = attention(
262
+ query,
263
+ key,
264
+ value,
265
+ attn_mask=attention_mask,
266
+ dropout_p=0.0,
267
+ is_causal=False,
268
+ )
269
+ hidden_states = hidden_states.flatten(2, 3)
270
+ hidden_states = hidden_states.to(query.dtype)
271
+
272
+ if encoder_hidden_states is not None:
273
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
274
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
275
+ )
276
+ hidden_states = attn.to_out[0](hidden_states)
277
+ hidden_states = attn.to_out[1](hidden_states)
278
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
279
+
280
+ # IP-adapter
281
+ ip_attn_output = torch.zeros_like(hidden_states)
282
+
283
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
284
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
285
+ ):
286
+ ip_key = to_k_ip(current_ip_hidden_states)
287
+ ip_value = to_v_ip(current_ip_hidden_states)
288
+
289
+ ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
290
+ ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
291
+
292
+ # NOTE: use the local `attention` helper from .attention_utils here, since
+ # `dispatch_attention_fn` is not imported in this module.
+ current_ip_hidden_states = attention(
293
+ ip_query,
294
+ ip_key,
295
+ ip_value,
296
+ attn_mask=None,
297
+ dropout_p=0.0,
298
+ is_causal=False,
299
+ # backend=self._attention_backend,  # the local `attention` helper has no backend argument
300
+ )
301
+ current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
302
+ current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
303
+ ip_attn_output += scale * current_ip_hidden_states
304
+
305
+ return hidden_states, encoder_hidden_states, ip_attn_output
306
+ else:
307
+ return hidden_states
308
+
309
+
310
+ class FluxAttention(torch.nn.Module):
311
+ _default_processor_cls = FluxAttnProcessor
312
+ _available_processors = [
313
+ FluxAttnProcessor,
314
+ FluxIPAdapterAttnProcessor,
315
+ ]
316
+
317
+ def __init__(
318
+ self,
319
+ query_dim: int,
320
+ heads: int = 8,
321
+ dim_head: int = 64,
322
+ dropout: float = 0.0,
323
+ bias: bool = False,
324
+ added_kv_proj_dim: Optional[int] = None,
325
+ added_proj_bias: Optional[bool] = True,
326
+ out_bias: bool = True,
327
+ eps: float = 1e-5,
328
+ out_dim: int = None,
329
+ context_pre_only: Optional[bool] = None,
330
+ pre_only: bool = False,
331
+ elementwise_affine: bool = True,
332
+ processor=None,
333
+ ):
334
+ super().__init__()
335
+
336
+ self.head_dim = dim_head
337
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
338
+ self.query_dim = query_dim
339
+ self.use_bias = bias
340
+ self.dropout = dropout
341
+ self.out_dim = out_dim if out_dim is not None else query_dim
342
+ self.context_pre_only = context_pre_only
343
+ self.pre_only = pre_only
344
+ self.heads = out_dim // dim_head if out_dim is not None else heads
345
+ self.added_kv_proj_dim = added_kv_proj_dim
346
+ self.added_proj_bias = added_proj_bias
347
+
348
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
349
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
350
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
351
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
352
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
353
+
354
+ if not self.pre_only:
355
+ self.to_out = torch.nn.ModuleList([])
356
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
357
+ self.to_out.append(torch.nn.Dropout(dropout))
358
+
359
+ if added_kv_proj_dim is not None:
360
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
361
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
362
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
363
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
364
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
365
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
366
+
367
+ if processor is None:
368
+ self.processor = self._default_processor_cls()
369
+ else:
370
+ self.processor = processor
371
+
372
+ def set_processor(self, processor: "AttnProcessor") -> None:
373
+ r"""
374
+ Set the attention processor to use.
375
+
376
+ Args:
377
+ processor (`AttnProcessor`):
378
+ The attention processor to use.
379
+ """
380
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
381
+ # pop `processor` from `self._modules`
382
+ if (
383
+ hasattr(self, "processor")
384
+ and isinstance(self.processor, torch.nn.Module)
385
+ and not isinstance(processor, torch.nn.Module)
386
+ ):
387
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
388
+ self._modules.pop("processor")
389
+
390
+ self.processor = processor
391
+
392
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
393
+ r"""
394
+ Get the attention processor in use.
395
+
396
+ Args:
397
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
398
+ Set to `True` to return the deprecated LoRA attention processor.
399
+
400
+ Returns:
401
+ "AttentionProcessor": The attention processor in use.
402
+ """
403
+ if not return_deprecated_lora:
404
+ return self.processor
405
+
406
+ def forward(
407
+ self,
408
+ hidden_states: torch.Tensor,
409
+ encoder_hidden_states: Optional[torch.Tensor] = None,
410
+ attention_mask: Optional[torch.Tensor] = None,
411
+ image_rotary_emb: Optional[torch.Tensor] = None,
412
+ **kwargs,
413
+ ) -> torch.Tensor:
414
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
415
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
416
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
417
+ if len(unused_kwargs) > 0:
418
+ logger.warning(
419
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
420
+ )
421
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
422
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
423
+
424
+
425
+ @maybe_allow_in_graph
426
+ class FluxSingleTransformerBlock(nn.Module):
427
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
428
+ super().__init__()
429
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
430
+
431
+ self.norm = AdaLayerNormZeroSingle(dim)
432
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
433
+ self.act_mlp = nn.GELU(approximate="tanh")
434
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
435
+
436
+ self.attn = FluxAttention(
437
+ query_dim=dim,
438
+ dim_head=attention_head_dim,
439
+ heads=num_attention_heads,
440
+ out_dim=dim,
441
+ bias=True,
442
+ processor=FluxAttnProcessor(),
443
+ eps=1e-6,
444
+ pre_only=True,
445
+ )
446
+
447
+ def forward(
448
+ self,
449
+ hidden_states: torch.Tensor,
450
+ encoder_hidden_states: torch.Tensor,
451
+ temb: torch.Tensor,
452
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
453
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
454
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
455
+ text_seq_len = encoder_hidden_states.shape[1]
456
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
457
+
458
+ residual = hidden_states
459
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
460
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
461
+ joint_attention_kwargs = joint_attention_kwargs or {}
462
+ attn_output = self.attn(
463
+ hidden_states=norm_hidden_states,
464
+ image_rotary_emb=image_rotary_emb,
465
+ text_seq_len=text_seq_len,
466
+ **joint_attention_kwargs,
467
+ )
468
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
469
+ gate = gate.unsqueeze(1)
470
+ hidden_states = gate * self.proj_out(hidden_states)
471
+ hidden_states = residual + hidden_states
472
+ if hidden_states.dtype == torch.float16:
473
+ hidden_states = hidden_states.clip(-65504, 65504)
474
+
475
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
476
+ return encoder_hidden_states, hidden_states
477
+
478
+
479
+ @maybe_allow_in_graph
480
+ class FluxTransformerBlock(nn.Module):
481
+ def __init__(
482
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
483
+ ):
484
+ super().__init__()
485
+
486
+ self.norm1 = AdaLayerNormZero(dim)
487
+ self.norm1_context = AdaLayerNormZero(dim)
488
+
489
+ self.attn = FluxAttention(
490
+ query_dim=dim,
491
+ added_kv_proj_dim=dim,
492
+ dim_head=attention_head_dim,
493
+ heads=num_attention_heads,
494
+ out_dim=dim,
495
+ context_pre_only=False,
496
+ bias=True,
497
+ processor=FluxAttnProcessor(),
498
+ eps=eps,
499
+ )
500
+
501
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
502
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
503
+
504
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
505
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
506
+
507
+ def forward(
508
+ self,
509
+ hidden_states: torch.Tensor,
510
+ encoder_hidden_states: torch.Tensor,
511
+ temb: torch.Tensor,
512
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
513
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
514
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
515
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
516
+
517
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
518
+ encoder_hidden_states, emb=temb
519
+ )
520
+ joint_attention_kwargs = joint_attention_kwargs or {}
521
+
522
+ # Attention.
523
+ attention_outputs = self.attn(
524
+ hidden_states=norm_hidden_states,
525
+ encoder_hidden_states=norm_encoder_hidden_states,
526
+ image_rotary_emb=image_rotary_emb,
527
+ **joint_attention_kwargs,
528
+ )
529
+
530
+ if len(attention_outputs) == 2:
531
+ attn_output, context_attn_output = attention_outputs
532
+ elif len(attention_outputs) == 3:
533
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
534
+
535
+ # Process attention outputs for the `hidden_states`.
536
+ attn_output = gate_msa.unsqueeze(1) * attn_output
537
+ hidden_states = hidden_states + attn_output
538
+
539
+ norm_hidden_states = self.norm2(hidden_states)
540
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
541
+
542
+ ff_output = self.ff(norm_hidden_states)
543
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
544
+
545
+ hidden_states = hidden_states + ff_output
546
+ if len(attention_outputs) == 3:
547
+ hidden_states = hidden_states + ip_attn_output
548
+
549
+ # Process attention outputs for the `encoder_hidden_states`.
550
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
551
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
552
+
553
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
554
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
555
+
556
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
557
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
558
+ if encoder_hidden_states.dtype == torch.float16:
559
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
560
+
561
+ return encoder_hidden_states, hidden_states
562
+
563
+
564
+ class FluxPosEmbed(nn.Module):
565
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
566
+ def __init__(self, theta: int, axes_dim: List[int]):
567
+ super().__init__()
568
+ self.theta = theta
569
+ self.axes_dim = axes_dim
570
+
571
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
572
+ n_axes = ids.shape[-1]
573
+ cos_out = []
574
+ sin_out = []
575
+ pos = ids.float()
576
+ is_mps = ids.device.type == "mps"
577
+ is_npu = ids.device.type == "npu"
578
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
579
+ for i in range(n_axes):
580
+ cos, sin = get_1d_rotary_pos_embed(
581
+ self.axes_dim[i],
582
+ pos[:, i],
583
+ theta=self.theta,
584
+ repeat_interleave_real=True,
585
+ use_real=True,
586
+ freqs_dtype=freqs_dtype,
587
+ )
588
+ cos_out.append(cos)
589
+ sin_out.append(sin)
590
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
591
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
592
+ return freqs_cos, freqs_sin
593
+
594
+
595
+ class FluxTransformer2DModel(
596
+ ModelMixin,
597
+ ConfigMixin,
598
+ PeftAdapterMixin,
599
+ FromOriginalModelMixin,
600
+ ):
601
+ """
602
+ The Transformer model introduced in Flux.
603
+
604
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
605
+
606
+ Args:
607
+ patch_size (`int`, defaults to `1`):
608
+ Patch size to turn the input data into small patches.
609
+ in_channels (`int`, defaults to `64`):
610
+ The number of channels in the input.
611
+ out_channels (`int`, *optional*, defaults to `None`):
612
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
613
+ num_layers (`int`, defaults to `19`):
614
+ The number of layers of dual stream DiT blocks to use.
615
+ num_single_layers (`int`, defaults to `38`):
616
+ The number of layers of single stream DiT blocks to use.
617
+ attention_head_dim (`int`, defaults to `128`):
618
+ The number of dimensions to use for each attention head.
619
+ num_attention_heads (`int`, defaults to `24`):
620
+ The number of attention heads to use.
621
+ joint_attention_dim (`int`, defaults to `4096`):
622
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
623
+ `encoder_hidden_states`).
624
+ pooled_projection_dim (`int`, defaults to `768`):
625
+ The number of dimensions to use for the pooled projection.
626
+ guidance_embeds (`bool`, defaults to `False`):
627
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
628
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
629
+ The dimensions to use for the rotary positional embeddings.
630
+ """
631
+
632
+ _supports_gradient_checkpointing = True
633
+ # _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
634
+ # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
635
+ # _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
636
+
637
+ @register_to_config
638
+ def __init__(
639
+ self,
640
+ patch_size: int = 1,
641
+ in_channels: int = 64,
642
+ out_channels: Optional[int] = None,
643
+ num_layers: int = 19,
644
+ num_single_layers: int = 38,
645
+ attention_head_dim: int = 128,
646
+ num_attention_heads: int = 24,
647
+ joint_attention_dim: int = 4096,
648
+ pooled_projection_dim: int = 768,
649
+ guidance_embeds: bool = False,
650
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
651
+ ):
652
+ super().__init__()
653
+ self.out_channels = out_channels or in_channels
654
+ self.inner_dim = num_attention_heads * attention_head_dim
655
+
656
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
657
+
658
+ text_time_guidance_cls = (
659
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
660
+ )
661
+ self.time_text_embed = text_time_guidance_cls(
662
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
663
+ )
664
+
665
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
666
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
667
+
668
+ self.transformer_blocks = nn.ModuleList(
669
+ [
670
+ FluxTransformerBlock(
671
+ dim=self.inner_dim,
672
+ num_attention_heads=num_attention_heads,
673
+ attention_head_dim=attention_head_dim,
674
+ )
675
+ for _ in range(num_layers)
676
+ ]
677
+ )
678
+
679
+ self.single_transformer_blocks = nn.ModuleList(
680
+ [
681
+ FluxSingleTransformerBlock(
682
+ dim=self.inner_dim,
683
+ num_attention_heads=num_attention_heads,
684
+ attention_head_dim=attention_head_dim,
685
+ )
686
+ for _ in range(num_single_layers)
687
+ ]
688
+ )
689
+
690
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
691
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
692
+
693
+ self.gradient_checkpointing = False
694
+
695
+ self.sp_world_size = 1
696
+ self.sp_world_rank = 0
697
+
698
+ def enable_multi_gpus_inference(self,):
699
+ self.sp_world_size = get_sequence_parallel_world_size()
700
+ self.sp_world_rank = get_sequence_parallel_rank()
701
+ self.all_gather = get_sp_group().all_gather
702
+ self.set_attn_processor(FluxMultiGPUsAttnProcessor2_0())
703
+
704
+ @property
705
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
706
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
707
+ r"""
708
+ Returns:
709
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
710
+ indexed by their weight names.
711
+ """
712
+ # set recursively
713
+ processors = {}
714
+
715
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
716
+ if hasattr(module, "get_processor"):
717
+ processors[f"{name}.processor"] = module.get_processor()
718
+
719
+ for sub_name, child in module.named_children():
720
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
721
+
722
+ return processors
723
+
724
+ for name, module in self.named_children():
725
+ fn_recursive_add_processors(name, module, processors)
726
+
727
+ return processors
728
+
729
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
730
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
731
+ r"""
732
+ Sets the attention processor to use to compute attention.
733
+
734
+ Parameters:
735
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
736
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
737
+ for **all** `Attention` layers.
738
+
739
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
740
+ processor. This is strongly recommended when setting trainable attention processors.
741
+
742
+ """
743
+ count = len(self.attn_processors.keys())
744
+
745
+ if isinstance(processor, dict) and len(processor) != count:
746
+ raise ValueError(
747
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
748
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
749
+ )
750
+
751
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
752
+ if hasattr(module, "set_processor"):
753
+ if not isinstance(processor, dict):
754
+ module.set_processor(processor)
755
+ else:
756
+ module.set_processor(processor.pop(f"{name}.processor"))
757
+
758
+ for sub_name, child in module.named_children():
759
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
760
+
761
+ for name, module in self.named_children():
762
+ fn_recursive_attn_processor(name, module, processor)
763
+
764
+ def forward(
765
+ self,
766
+ hidden_states: torch.Tensor,
767
+ encoder_hidden_states: torch.Tensor = None,
768
+ pooled_projections: torch.Tensor = None,
769
+ timestep: torch.LongTensor = None,
770
+ img_ids: torch.Tensor = None,
771
+ txt_ids: torch.Tensor = None,
772
+ guidance: torch.Tensor = None,
773
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
774
+ controlnet_block_samples=None,
775
+ controlnet_single_block_samples=None,
776
+ return_dict: bool = True,
777
+ controlnet_blocks_repeat: bool = False,
778
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
779
+ """
780
+ The [`FluxTransformer2DModel`] forward method.
781
+
782
+ Args:
783
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
784
+ Input `hidden_states`.
785
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
786
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
787
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
788
+ from the embeddings of input conditions.
789
+ timestep ( `torch.LongTensor`):
790
+ Used to indicate denoising step.
791
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
792
+ A list of tensors that if specified are added to the residuals of transformer blocks.
793
+ joint_attention_kwargs (`dict`, *optional*):
794
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
795
+ `self.processor` in
796
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
797
+ return_dict (`bool`, *optional*, defaults to `True`):
798
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
799
+ tuple.
800
+
801
+ Returns:
802
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
803
+ `tuple` where the first element is the sample tensor.
804
+ """
805
+ if joint_attention_kwargs is not None:
806
+ joint_attention_kwargs = joint_attention_kwargs.copy()
807
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
808
+ else:
809
+ lora_scale = 1.0
810
+
811
+ if USE_PEFT_BACKEND:
812
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
813
+ scale_lora_layers(self, lora_scale)
814
+ else:
815
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
816
+ logger.warning(
817
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
818
+ )
819
+
820
+ hidden_states = self.x_embedder(hidden_states)
821
+
822
+ timestep = timestep.to(hidden_states.dtype) * 1000
823
+ if guidance is not None:
824
+ guidance = guidance.to(hidden_states.dtype) * 1000
825
+
826
+ temb = (
827
+ self.time_text_embed(timestep, pooled_projections)
828
+ if guidance is None
829
+ else self.time_text_embed(timestep, guidance, pooled_projections)
830
+ )
831
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
832
+
833
+ if txt_ids.ndim == 3:
834
+ logger.warning(
835
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
836
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
837
+ )
838
+ txt_ids = txt_ids[0]
839
+ if img_ids.ndim == 3:
840
+ logger.warning(
841
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
842
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
843
+ )
844
+ img_ids = img_ids[0]
845
+
846
+ ids = torch.cat((txt_ids, img_ids), dim=0)
847
+ image_rotary_emb = self.pos_embed(ids)
848
+
849
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
850
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
851
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
852
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
853
+
854
+ # Context Parallel
855
+ if self.sp_world_size > 1:
856
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
857
+ if image_rotary_emb is not None:
858
+ txt_rotary_emb = (
859
+ image_rotary_emb[0][:encoder_hidden_states.shape[1]],
860
+ image_rotary_emb[1][:encoder_hidden_states.shape[1]]
861
+ )
862
+ image_rotary_emb = (
863
+ torch.chunk(image_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
864
+ torch.chunk(image_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
865
+ )
866
+ image_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \
867
+ for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, image_rotary_emb)]
868
+
869
+ for index_block, block in enumerate(self.transformer_blocks):
870
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
871
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
872
+ block,
873
+ hidden_states,
874
+ encoder_hidden_states,
875
+ temb,
876
+ image_rotary_emb,
877
+ joint_attention_kwargs,
878
+ )
879
+
880
+ else:
881
+ encoder_hidden_states, hidden_states = block(
882
+ hidden_states=hidden_states,
883
+ encoder_hidden_states=encoder_hidden_states,
884
+ temb=temb,
885
+ image_rotary_emb=image_rotary_emb,
886
+ joint_attention_kwargs=joint_attention_kwargs,
887
+ )
888
+
889
+ # controlnet residual
890
+ if controlnet_block_samples is not None:
891
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
892
+ interval_control = int(np.ceil(interval_control))
893
+ # For Xlabs ControlNet.
894
+ if controlnet_blocks_repeat:
895
+ hidden_states = (
896
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
897
+ )
898
+ else:
899
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
900
+
901
+ for index_block, block in enumerate(self.single_transformer_blocks):
902
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
903
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
904
+ block,
905
+ hidden_states,
906
+ encoder_hidden_states,
907
+ temb,
908
+ image_rotary_emb,
909
+ joint_attention_kwargs,
910
+ )
911
+
912
+ else:
913
+ encoder_hidden_states, hidden_states = block(
914
+ hidden_states=hidden_states,
915
+ encoder_hidden_states=encoder_hidden_states,
916
+ temb=temb,
917
+ image_rotary_emb=image_rotary_emb,
918
+ joint_attention_kwargs=joint_attention_kwargs,
919
+ )
920
+
921
+ # controlnet residual
922
+ if controlnet_single_block_samples is not None:
923
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
924
+ interval_control = int(np.ceil(interval_control))
925
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
926
+
927
+ hidden_states = self.norm_out(hidden_states, temb)
928
+ output = self.proj_out(hidden_states)
929
+
930
+ if self.sp_world_size > 1:
931
+ output = self.all_gather(output, dim=1)
932
+
933
+ if USE_PEFT_BACKEND:
934
+ # remove `lora_scale` from each PEFT layer
935
+ unscale_lora_layers(self, lora_scale)
936
+
937
+ if not return_dict:
938
+ return (output,)
939
+
940
+ return Transformer2DModelOutput(sample=output)
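
A minimal standalone sketch (not from the file above) of the rotary-embedding math that `apply_rotary_emb` performs with `use_real=True` and the default `sequence_dim=2`; the shapes and rotation angles are arbitrary toy values:

import torch

B, H, S, D = 1, 2, 4, 8                          # toy shapes; D must be even
x = torch.randn(B, H, S, D)                      # query or key tensor, [B, H, S, D]
angles = torch.randn(S, D // 2)                  # arbitrary per-position rotation angles
cos = angles.cos().repeat_interleave(2, dim=-1)  # [S, D], as with repeat_interleave_real=True
sin = angles.sin().repeat_interleave(2, dim=-1)

# rotate channel pairs (x0, x1) -> (-x1, x0), then combine with cos/sin
x_real, x_imag = x.reshape(B, H, S, -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = x * cos[None, None] + x_rotated * sin[None, None]
# `out` matches apply_rotary_emb(x, (cos, sin)) from the module above
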
videox_fun/models/qwenimage_transformer2d.py ADDED
@@ -0,0 +1,893 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_qwenimage.py
2
+ # Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import functools
18
+ import glob
19
+ import json
20
+ import math
21
+ import os
22
+ import types
23
+ import warnings
24
+ from typing import Any, Dict, List, Optional, Tuple, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.cuda.amp as amp
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
32
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
33
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
34
+ from diffusers.models.attention import Attention, FeedForward
35
+ from diffusers.models.attention_processor import (
36
+ Attention, AttentionProcessor, CogVideoXAttnProcessor2_0,
37
+ FusedCogVideoXAttnProcessor2_0)
38
+ from diffusers.models.embeddings import (CogVideoXPatchEmbed,
39
+ TimestepEmbedding, Timesteps,
40
+ get_3d_sincos_pos_embed)
41
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
42
+ from diffusers.models.modeling_utils import ModelMixin
43
+ from diffusers.models.normalization import (AdaLayerNorm,
44
+ AdaLayerNormContinuous,
45
+ CogVideoXLayerNormZero, RMSNorm)
46
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
47
+ scale_lora_layers, unscale_lora_layers)
48
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
49
+ from torch import nn
50
+
51
+ from ..dist import (QwenImageMultiGPUsAttnProcessor2_0,
52
+ get_sequence_parallel_rank,
53
+ get_sequence_parallel_world_size, get_sp_group)
54
+ from .attention_utils import attention
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+
59
+ def get_timestep_embedding(
60
+ timesteps: torch.Tensor,
61
+ embedding_dim: int,
62
+ flip_sin_to_cos: bool = False,
63
+ downscale_freq_shift: float = 1,
64
+ scale: float = 1,
65
+ max_period: int = 10000,
66
+ ) -> torch.Tensor:
67
+ """
68
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
69
+
70
+ Args
71
+ timesteps (torch.Tensor):
72
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
73
+ embedding_dim (int):
74
+ the dimension of the output.
75
+ flip_sin_to_cos (bool):
76
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
77
+ downscale_freq_shift (float):
78
+ Controls the delta between frequencies between dimensions
79
+ scale (float):
80
+ Scaling factor applied to the embeddings.
81
+ max_period (int):
82
+ Controls the maximum frequency of the embeddings
83
+ Returns
84
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
85
+ """
86
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
87
+
88
+ half_dim = embedding_dim // 2
89
+ exponent = -math.log(max_period) * torch.arange(
90
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
91
+ )
92
+ exponent = exponent / (half_dim - downscale_freq_shift)
93
+
94
+ emb = torch.exp(exponent).to(timesteps.dtype)
95
+ emb = timesteps[:, None].float() * emb[None, :]
96
+
97
+ # scale embeddings
98
+ emb = scale * emb
99
+
100
+ # concat sine and cosine embeddings
101
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
102
+
103
+ # flip sine and cosine embeddings
104
+ if flip_sin_to_cos:
105
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
106
+
107
+ # zero pad
108
+ if embedding_dim % 2 == 1:
109
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
110
+ return emb
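
# (Not part of the diff) A standalone sketch of the sinusoidal layout built by
# get_timestep_embedding() above, using toy values embedding_dim=8 and
# downscale_freq_shift=0 instead of the default of 1 for readability:
import math
import torch

t = torch.tensor([1.0, 50.0])                  # toy timesteps
half = 4                                       # embedding_dim // 2
freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
emb = t[:, None] * freqs[None, :]              # [N, half] angles
emb = torch.cat([emb.sin(), emb.cos()], dim=-1)           # sin first, then cos
emb = torch.cat([emb[:, half:], emb[:, :half]], dim=-1)   # flip_sin_to_cos=True -> cos first
# emb equals get_timestep_embedding(t, 8, flip_sin_to_cos=True, downscale_freq_shift=0)
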
111
+
112
+
113
+ def apply_rotary_emb_qwen(
114
+ x: torch.Tensor,
115
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
116
+ use_real: bool = True,
117
+ use_real_unbind_dim: int = -1,
118
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
119
+ """
120
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
121
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
122
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
123
+ tensors contain rotary embeddings and are returned as real tensors.
124
+
125
+ Args:
126
+ x (`torch.Tensor`):
127
+ Query or key tensor to apply rotary embeddings to, of shape [B, S, H, D].
128
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
129
+
130
+ Returns:
131
+ torch.Tensor: The query or key tensor with rotary embeddings applied.
132
+ """
133
+ if use_real:
134
+ cos, sin = freqs_cis # [S, D]
135
+ cos = cos[None, None]
136
+ sin = sin[None, None]
137
+ cos, sin = cos.to(x.device), sin.to(x.device)
138
+
139
+ if use_real_unbind_dim == -1:
140
+ # Used for flux, cogvideox, hunyuan-dit
141
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
142
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
143
+ elif use_real_unbind_dim == -2:
144
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
145
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
146
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
147
+ else:
148
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
149
+
150
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
151
+
152
+ return out
153
+ else:
154
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
155
+ freqs_cis = freqs_cis.unsqueeze(1)
156
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
157
+
158
+ return x_out.type_as(x)
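+
+ # Illustrative sketch (not from the original commit): the complex-valued path
+ # (use_real=False) of apply_rotary_emb_qwen on a toy query tensor. Shapes follow the
+ # convention used by the attention processor further below: x is [B, S, H, D] and
+ # freqs_cis is a complex tensor of shape [S, D // 2].
+ #
+ #   q = torch.randn(1, 16, 8, 64)                               # [B, S, H, D]
+ #   angles = torch.randn(16, 32)                                # [S, D // 2]
+ #   freqs = torch.polar(torch.ones_like(angles), angles)        # complex64
+ #   q_rot = apply_rotary_emb_qwen(q, freqs, use_real=False)
+ #   q_rot.shape  # torch.Size([1, 16, 8, 64]), same dtype as q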
159
+
160
+
161
+ class QwenTimestepProjEmbeddings(nn.Module):
162
+ def __init__(self, embedding_dim):
163
+ super().__init__()
164
+
165
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
166
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
167
+
168
+ def forward(self, timestep, hidden_states):
169
+ timesteps_proj = self.time_proj(timestep)
170
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
171
+
172
+ conditioning = timesteps_emb
173
+
174
+ return conditioning
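+
+ # Illustrative sketch (not from the original commit): the projection above maps a batch
+ # of (possibly fractional) timesteps to the transformer's inner dimension. The
+ # embedding_dim value and tensors below are hypothetical.
+ #
+ #   embed = QwenTimestepProjEmbeddings(embedding_dim=3072)
+ #   t = torch.tensor([0.3, 0.7])                                # one timestep per sample
+ #   cond = embed(t, hidden_states=torch.randn(2, 10, 3072))     # dtype reference only
+ #   cond.shape  # torch.Size([2, 3072])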
175
+
176
+
177
+ class QwenEmbedRope(nn.Module):
178
+ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
179
+ super().__init__()
180
+ self.theta = theta
181
+ self.axes_dim = axes_dim
182
+ pos_index = torch.arange(4096)
183
+ neg_index = torch.arange(4096).flip(0) * -1 - 1
184
+ self.pos_freqs = torch.cat(
185
+ [
186
+ self.rope_params(pos_index, self.axes_dim[0], self.theta),
187
+ self.rope_params(pos_index, self.axes_dim[1], self.theta),
188
+ self.rope_params(pos_index, self.axes_dim[2], self.theta),
189
+ ],
190
+ dim=1,
191
+ )
192
+ self.neg_freqs = torch.cat(
193
+ [
194
+ self.rope_params(neg_index, self.axes_dim[0], self.theta),
195
+ self.rope_params(neg_index, self.axes_dim[1], self.theta),
196
+ self.rope_params(neg_index, self.axes_dim[2], self.theta),
197
+ ],
198
+ dim=1,
199
+ )
200
+ self.rope_cache = {}
201
+
202
+ # DO NOT USE register_buffer HERE: IT WOULD CAUSE THE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
203
+ self.scale_rope = scale_rope
204
+
205
+ def rope_params(self, index, dim, theta=10000):
206
+ """
207
+ Args:
208
+ index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
209
+ """
210
+ assert dim % 2 == 0
211
+ freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
212
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
213
+ return freqs
214
+
215
+ def forward(self, video_fhw, txt_seq_lens, device):
216
+ """
217
+ Args:
+ video_fhw: [frame, height, width] a tuple (or list of tuples) of 3 integers describing the latent video/image grid
218
+ txt_seq_lens: [bs] a list of integers giving the text sequence length for each batch element
219
+ """
220
+ if self.pos_freqs.device != device:
221
+ self.pos_freqs = self.pos_freqs.to(device)
222
+ self.neg_freqs = self.neg_freqs.to(device)
223
+
224
+ if isinstance(video_fhw, list):
225
+ video_fhw = video_fhw[0]
226
+ if not isinstance(video_fhw, list):
227
+ video_fhw = [video_fhw]
228
+
229
+ vid_freqs = []
230
+ max_vid_index = 0
231
+ for idx, fhw in enumerate(video_fhw):
232
+ frame, height, width = fhw
233
+ rope_key = f"{idx}_{height}_{width}"
234
+
235
+ if not torch.compiler.is_compiling():
236
+ if rope_key not in self.rope_cache:
237
+ self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
238
+ video_freq = self.rope_cache[rope_key]
239
+ else:
240
+ video_freq = self._compute_video_freqs(frame, height, width, idx)
241
+ video_freq = video_freq.to(device)
242
+ vid_freqs.append(video_freq)
243
+
244
+ if self.scale_rope:
245
+ max_vid_index = max(height // 2, width // 2, max_vid_index)
246
+ else:
247
+ max_vid_index = max(height, width, max_vid_index)
248
+
249
+ max_len = max(txt_seq_lens)
250
+ txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
251
+ vid_freqs = torch.cat(vid_freqs, dim=0)
252
+
253
+ return vid_freqs, txt_freqs
254
+
255
+ @functools.lru_cache(maxsize=None)
256
+ def _compute_video_freqs(self, frame, height, width, idx=0):
257
+ seq_lens = frame * height * width
258
+ freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
259
+ freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
260
+
261
+ freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
262
+ if self.scale_rope:
263
+ freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
264
+ freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
265
+ freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
266
+ freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
267
+ else:
268
+ freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
269
+ freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
270
+
271
+ freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
272
+ return freqs.clone().contiguous()
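+
+ # Illustrative sketch (not from the original commit): QwenEmbedRope is queried once per
+ # forward pass with the latent grid shape(s) and the text lengths; the numbers below are
+ # hypothetical.
+ #
+ #   rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
+ #   vid_freqs, txt_freqs = rope([(1, 32, 32)], txt_seq_lens=[77], device=torch.device("cpu"))
+ #   vid_freqs.shape  # [1 * 32 * 32, (16 + 56 + 56) // 2] complex frequencies
+ #   txt_freqs.shape  # [77, 64]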
273
+
274
+
275
+ class QwenDoubleStreamAttnProcessor2_0:
276
+ """
277
+ Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
278
+ implements joint attention computation where text and image streams are processed together.
279
+ """
280
+
281
+ _attention_backend = None
282
+
283
+ def __init__(self):
284
+ if not hasattr(F, "scaled_dot_product_attention"):
285
+ raise ImportError(
286
+ "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
287
+ )
288
+
289
+ def __call__(
290
+ self,
291
+ attn: Attention,
292
+ hidden_states: torch.FloatTensor, # Image stream
293
+ encoder_hidden_states: torch.FloatTensor = None, # Text stream
294
+ encoder_hidden_states_mask: torch.FloatTensor = None,
295
+ attention_mask: Optional[torch.FloatTensor] = None,
296
+ image_rotary_emb: Optional[torch.Tensor] = None,
297
+ ) -> torch.FloatTensor:
298
+ if encoder_hidden_states is None:
299
+ raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
300
+
301
+ seq_txt = encoder_hidden_states.shape[1]
302
+
303
+ # Compute QKV for image stream (sample projections)
304
+ img_query = attn.to_q(hidden_states)
305
+ img_key = attn.to_k(hidden_states)
306
+ img_value = attn.to_v(hidden_states)
307
+
308
+ # Compute QKV for text stream (context projections)
309
+ txt_query = attn.add_q_proj(encoder_hidden_states)
310
+ txt_key = attn.add_k_proj(encoder_hidden_states)
311
+ txt_value = attn.add_v_proj(encoder_hidden_states)
312
+
313
+ # Reshape for multi-head attention
314
+ img_query = img_query.unflatten(-1, (attn.heads, -1))
315
+ img_key = img_key.unflatten(-1, (attn.heads, -1))
316
+ img_value = img_value.unflatten(-1, (attn.heads, -1))
317
+
318
+ txt_query = txt_query.unflatten(-1, (attn.heads, -1))
319
+ txt_key = txt_key.unflatten(-1, (attn.heads, -1))
320
+ txt_value = txt_value.unflatten(-1, (attn.heads, -1))
321
+
322
+ # Apply QK normalization
323
+ if attn.norm_q is not None:
324
+ img_query = attn.norm_q(img_query)
325
+ if attn.norm_k is not None:
326
+ img_key = attn.norm_k(img_key)
327
+ if attn.norm_added_q is not None:
328
+ txt_query = attn.norm_added_q(txt_query)
329
+ if attn.norm_added_k is not None:
330
+ txt_key = attn.norm_added_k(txt_key)
331
+
332
+ # Apply RoPE
333
+ if image_rotary_emb is not None:
334
+ img_freqs, txt_freqs = image_rotary_emb
335
+ img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
336
+ img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
337
+ txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
338
+ txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
339
+
340
+ # Concatenate for joint attention
341
+ # Order: [text, image]
342
+ joint_query = torch.cat([txt_query, img_query], dim=1)
343
+ joint_key = torch.cat([txt_key, img_key], dim=1)
344
+ joint_value = torch.cat([txt_value, img_value], dim=1)
345
+
346
+ joint_hidden_states = attention(
347
+ joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, causal=False
348
+ )
349
+
350
+ # Reshape back
351
+ joint_hidden_states = joint_hidden_states.flatten(2, 3)
352
+ joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
353
+
354
+ # Split attention outputs back
355
+ txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
356
+ img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
357
+
358
+ # Apply output projections
359
+ img_attn_output = attn.to_out[0](img_attn_output)
360
+ if len(attn.to_out) > 1:
361
+ img_attn_output = attn.to_out[1](img_attn_output) # dropout
362
+
363
+ txt_attn_output = attn.to_add_out(txt_attn_output)
364
+
365
+ return img_attn_output, txt_attn_output
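+
+ # Illustrative note (not from the original commit): with a text stream of length S_txt
+ # and an image stream of length S_img, the joint sequence assembled above has length
+ # S_txt + S_img; the first S_txt tokens of the attention output are routed back to the
+ # text stream and the remaining S_img tokens to the image stream.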
366
+
367
+
368
+ @maybe_allow_in_graph
369
+ class QwenImageTransformerBlock(nn.Module):
370
+ def __init__(
371
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
372
+ ):
373
+ super().__init__()
374
+
375
+ self.dim = dim
376
+ self.num_attention_heads = num_attention_heads
377
+ self.attention_head_dim = attention_head_dim
378
+
379
+ # Image processing modules
380
+ self.img_mod = nn.Sequential(
381
+ nn.SiLU(),
382
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
383
+ )
384
+ self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
385
+ self.attn = Attention(
386
+ query_dim=dim,
387
+ cross_attention_dim=None, # Enable cross attention for joint computation
388
+ added_kv_proj_dim=dim, # Enable added KV projections for text stream
389
+ dim_head=attention_head_dim,
390
+ heads=num_attention_heads,
391
+ out_dim=dim,
392
+ context_pre_only=False,
393
+ bias=True,
394
+ processor=QwenDoubleStreamAttnProcessor2_0(),
395
+ qk_norm=qk_norm,
396
+ eps=eps,
397
+ )
398
+ self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
399
+ self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
400
+
401
+ # Text processing modules
402
+ self.txt_mod = nn.Sequential(
403
+ nn.SiLU(),
404
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
405
+ )
406
+ self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
407
+ # Text doesn't need separate attention - it's handled by img_attn joint computation
408
+ self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
409
+ self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
410
+
411
+ def _modulate(self, x, mod_params):
412
+ """Apply modulation to input tensor"""
413
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
414
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
415
+
416
+ def forward(
417
+ self,
418
+ hidden_states: torch.Tensor,
419
+ encoder_hidden_states: torch.Tensor,
420
+ encoder_hidden_states_mask: torch.Tensor,
421
+ temb: torch.Tensor,
422
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
423
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
424
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
425
+ # Get modulation parameters for both streams
426
+ img_mod_params = self.img_mod(temb) # [B, 6*dim]
427
+ txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
428
+
429
+ # Split modulation parameters for norm1 and norm2
430
+ img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
431
+ txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
432
+
433
+ # Process image stream - norm1 + modulation
434
+ img_normed = self.img_norm1(hidden_states)
435
+ img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
436
+
437
+ # Process text stream - norm1 + modulation
438
+ txt_normed = self.txt_norm1(encoder_hidden_states)
439
+ txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
440
+
441
+ # Use QwenDoubleStreamAttnProcessor2_0 for joint attention computation
442
+ # This directly implements the DoubleStreamLayerMegatron logic:
443
+ # 1. Computes QKV for both streams
444
+ # 2. Applies QK normalization and RoPE
445
+ # 3. Concatenates and runs joint attention
446
+ # 4. Splits results back to separate streams
447
+ joint_attention_kwargs = joint_attention_kwargs or {}
448
+ attn_output = self.attn(
449
+ hidden_states=img_modulated, # Image stream (will be processed as "sample")
450
+ encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
451
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
452
+ image_rotary_emb=image_rotary_emb,
453
+ **joint_attention_kwargs,
454
+ )
455
+
456
+ # QwenDoubleStreamAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
457
+ img_attn_output, txt_attn_output = attn_output
458
+
459
+ # Apply attention gates and add residual (like in Megatron)
460
+ hidden_states = hidden_states + img_gate1 * img_attn_output
461
+ encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
462
+
463
+ # Process image stream - norm2 + MLP
464
+ img_normed2 = self.img_norm2(hidden_states)
465
+ img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
466
+ img_mlp_output = self.img_mlp(img_modulated2)
467
+ hidden_states = hidden_states + img_gate2 * img_mlp_output
468
+
469
+ # Process text stream - norm2 + MLP
470
+ txt_normed2 = self.txt_norm2(encoder_hidden_states)
471
+ txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
472
+ txt_mlp_output = self.txt_mlp(txt_modulated2)
473
+ encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
474
+
475
+ # Clip to prevent overflow for fp16
476
+ if encoder_hidden_states.dtype == torch.float16:
477
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
478
+ if hidden_states.dtype == torch.float16:
479
+ hidden_states = hidden_states.clip(-65504, 65504)
480
+
481
+ return encoder_hidden_states, hidden_states
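+
+ # Illustrative sketch (not from the original commit): the AdaLN-style modulation
+ # performed by _modulate above. Each stream's modulation projection (SiLU followed by a
+ # Linear) turns temb into 6*dim values, split into (shift, scale, gate) once for the
+ # attention branch and once for the MLP branch, and applied as
+ #
+ #   x_mod = norm(x) * (1 + scale) + shift     # pre-branch modulation
+ #   x     = x + gate * branch(x_mod)          # gated residual update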
482
+
483
+
484
+ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
485
+ """
486
+ The Transformer model introduced in Qwen.
487
+
488
+ Args:
489
+ patch_size (`int`, defaults to `2`):
490
+ Patch size to turn the input data into small patches.
491
+ in_channels (`int`, defaults to `64`):
492
+ The number of channels in the input.
493
+ out_channels (`int`, *optional*, defaults to `None`):
494
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
495
+ num_layers (`int`, defaults to `60`):
496
+ The number of layers of dual stream DiT blocks to use.
497
+ attention_head_dim (`int`, defaults to `128`):
498
+ The number of dimensions to use for each attention head.
499
+ num_attention_heads (`int`, defaults to `24`):
500
+ The number of attention heads to use.
501
+ joint_attention_dim (`int`, defaults to `3584`):
502
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
503
+ `encoder_hidden_states`).
504
+ guidance_embeds (`bool`, defaults to `False`):
505
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
506
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
507
+ The dimensions to use for the rotary positional embeddings.
508
+ """
509
+
510
+ # _supports_gradient_checkpointing = True
511
+ # _no_split_modules = ["QwenImageTransformerBlock"]
512
+ # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
513
+ # _repeated_blocks = ["QwenImageTransformerBlock"]
514
+ _supports_gradient_checkpointing = True
515
+
516
+ @register_to_config
517
+ def __init__(
518
+ self,
519
+ patch_size: int = 2,
520
+ in_channels: int = 64,
521
+ out_channels: Optional[int] = 16,
522
+ num_layers: int = 60,
523
+ attention_head_dim: int = 128,
524
+ num_attention_heads: int = 24,
525
+ joint_attention_dim: int = 3584,
526
+ guidance_embeds: bool = False, # TODO: this should probably be removed
527
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
528
+ ):
529
+ super().__init__()
530
+ self.out_channels = out_channels or in_channels
531
+ self.inner_dim = num_attention_heads * attention_head_dim
532
+
533
+ self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
534
+
535
+ self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
536
+
537
+ self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
538
+
539
+ self.img_in = nn.Linear(in_channels, self.inner_dim)
540
+ self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
541
+
542
+ self.transformer_blocks = nn.ModuleList(
543
+ [
544
+ QwenImageTransformerBlock(
545
+ dim=self.inner_dim,
546
+ num_attention_heads=num_attention_heads,
547
+ attention_head_dim=attention_head_dim,
548
+ )
549
+ for _ in range(num_layers)
550
+ ]
551
+ )
552
+
553
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
554
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
555
+
556
+ self.gradient_checkpointing = False
557
+
558
+ self.sp_world_size = 1
559
+ self.sp_world_rank = 0
560
+
561
+ def _set_gradient_checkpointing(self, *args, **kwargs):
562
+ if "value" in kwargs:
563
+ self.gradient_checkpointing = kwargs["value"]
564
+ elif "enable" in kwargs:
565
+ self.gradient_checkpointing = kwargs["enable"]
566
+ else:
567
+ raise ValueError("Invalid set gradient checkpointing")
568
+
569
+ def enable_multi_gpus_inference(self,):
570
+ self.sp_world_size = get_sequence_parallel_world_size()
571
+ self.sp_world_rank = get_sequence_parallel_rank()
572
+ self.all_gather = get_sp_group().all_gather
573
+ self.set_attn_processor(QwenImageMultiGPUsAttnProcessor2_0())
574
+
575
+ @property
576
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
577
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
578
+ r"""
579
+ Returns:
580
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
581
+ indexed by their weight names.
582
+ """
583
+ # set recursively
584
+ processors = {}
585
+
586
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
587
+ if hasattr(module, "get_processor"):
588
+ processors[f"{name}.processor"] = module.get_processor()
589
+
590
+ for sub_name, child in module.named_children():
591
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
592
+
593
+ return processors
594
+
595
+ for name, module in self.named_children():
596
+ fn_recursive_add_processors(name, module, processors)
597
+
598
+ return processors
599
+
600
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
601
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
602
+ r"""
603
+ Sets the attention processor to use to compute attention.
604
+
605
+ Parameters:
606
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
607
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
608
+ for **all** `Attention` layers.
609
+
610
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
611
+ processor. This is strongly recommended when setting trainable attention processors.
612
+
613
+ """
614
+ count = len(self.attn_processors.keys())
615
+
616
+ if isinstance(processor, dict) and len(processor) != count:
617
+ raise ValueError(
618
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
619
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
620
+ )
621
+
622
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
623
+ if hasattr(module, "set_processor"):
624
+ if not isinstance(processor, dict):
625
+ module.set_processor(processor)
626
+ else:
627
+ module.set_processor(processor.pop(f"{name}.processor"))
628
+
629
+ for sub_name, child in module.named_children():
630
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
631
+
632
+ for name, module in self.named_children():
633
+ fn_recursive_attn_processor(name, module, processor)
634
+
635
+ def forward(
636
+ self,
637
+ hidden_states: torch.Tensor,
638
+ encoder_hidden_states: torch.Tensor = None,
639
+ encoder_hidden_states_mask: torch.Tensor = None,
640
+ timestep: torch.LongTensor = None,
641
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
642
+ txt_seq_lens: Optional[List[int]] = None,
643
+ guidance: torch.Tensor = None, # TODO: this should probably be removed
644
+ attention_kwargs: Optional[Dict[str, Any]] = None,
645
+ return_dict: bool = True,
646
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
647
+ """
648
+ The [`QwenImageTransformer2DModel`] forward method.
649
+
650
+ Args:
651
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
652
+ Input `hidden_states`.
653
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
654
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
655
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
656
+ Mask of the input conditions.
657
+ timestep ( `torch.LongTensor`):
658
+ Used to indicate denoising step.
659
+ attention_kwargs (`dict`, *optional*):
660
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
661
+ `self.processor` in
662
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
663
+ return_dict (`bool`, *optional*, defaults to `True`):
664
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
665
+ tuple.
666
+
667
+ Returns:
668
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
669
+ `tuple` where the first element is the sample tensor.
670
+ """
671
+ if attention_kwargs is not None:
672
+ attention_kwargs = attention_kwargs.copy()
673
+ lora_scale = attention_kwargs.pop("scale", 1.0)
674
+ else:
675
+ lora_scale = 1.0
676
+
677
+ if USE_PEFT_BACKEND:
678
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
679
+ scale_lora_layers(self, lora_scale)
680
+ else:
681
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
682
+ logger.warning(
683
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
684
+ )
685
+
686
+ hidden_states = self.img_in(hidden_states)
687
+
688
+ timestep = timestep.to(hidden_states.dtype)
689
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
690
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
691
+
692
+ if guidance is not None:
693
+ guidance = guidance.to(hidden_states.dtype) * 1000
694
+
695
+ temb = (
696
+ self.time_text_embed(timestep, hidden_states)
697
+ if guidance is None
698
+ else self.time_text_embed(timestep, guidance, hidden_states)
699
+ )
700
+
701
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
702
+
703
+ # Context Parallel
704
+ if self.sp_world_size > 1:
705
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
706
+ if image_rotary_emb is not None:
707
+ image_rotary_emb = (
708
+ torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
709
+ image_rotary_emb[1]
710
+ )
711
+
712
+ for index_block, block in enumerate(self.transformer_blocks):
713
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
714
+ def create_custom_forward(module):
715
+ def custom_forward(*inputs):
716
+ return module(*inputs)
717
+
718
+ return custom_forward
719
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
720
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
721
+ create_custom_forward(block),
722
+ hidden_states,
723
+ encoder_hidden_states,
724
+ encoder_hidden_states_mask,
725
+ temb,
726
+ image_rotary_emb,
727
+ **ckpt_kwargs,
728
+ )
729
+
730
+ else:
731
+ encoder_hidden_states, hidden_states = block(
732
+ hidden_states=hidden_states,
733
+ encoder_hidden_states=encoder_hidden_states,
734
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
735
+ temb=temb,
736
+ image_rotary_emb=image_rotary_emb,
737
+ joint_attention_kwargs=attention_kwargs,
738
+ )
739
+
740
+ # Use only the image part (hidden_states) from the dual-stream blocks
741
+ hidden_states = self.norm_out(hidden_states, temb)
742
+ output = self.proj_out(hidden_states)
743
+
744
+ if self.sp_world_size > 1:
745
+ output = self.all_gather(output, dim=1)
746
+
747
+ if USE_PEFT_BACKEND:
748
+ # remove `lora_scale` from each PEFT layer
749
+ unscale_lora_layers(self, lora_scale)
750
+
751
+ if not return_dict:
752
+ return (output,)
753
+
754
+ return Transformer2DModelOutput(sample=output)
755
+
756
+ @classmethod
757
+ def from_pretrained(
758
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
759
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
760
+ ):
761
+ if subfolder is not None:
762
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
763
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
764
+
765
+ config_file = os.path.join(pretrained_model_path, 'config.json')
766
+ if not os.path.isfile(config_file):
767
+ raise RuntimeError(f"{config_file} does not exist")
768
+ with open(config_file, "r") as f:
769
+ config = json.load(f)
770
+
771
+ from diffusers.utils import WEIGHTS_NAME
772
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
773
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
774
+
775
+ if "dict_mapping" in transformer_additional_kwargs.keys():
776
+ for key in transformer_additional_kwargs["dict_mapping"]:
777
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
778
+
779
+ if low_cpu_mem_usage:
780
+ try:
781
+ import re
782
+
783
+ from diffusers import __version__ as diffusers_version
784
+ if diffusers_version >= "0.33.0":
785
+ from diffusers.models.model_loading_utils import \
786
+ load_model_dict_into_meta
787
+ else:
788
+ from diffusers.models.modeling_utils import \
789
+ load_model_dict_into_meta
790
+ from diffusers.utils import is_accelerate_available
791
+ if is_accelerate_available():
792
+ import accelerate
793
+
794
+ # Instantiate model with empty weights
795
+ with accelerate.init_empty_weights():
796
+ model = cls.from_config(config, **transformer_additional_kwargs)
797
+
798
+ param_device = "cpu"
799
+ if os.path.exists(model_file):
800
+ state_dict = torch.load(model_file, map_location="cpu")
801
+ elif os.path.exists(model_file_safetensors):
802
+ from safetensors.torch import load_file, safe_open
803
+ state_dict = load_file(model_file_safetensors)
804
+ else:
805
+ from safetensors.torch import load_file, safe_open
806
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
807
+ state_dict = {}
808
+ print(model_files_safetensors)
809
+ for _model_file_safetensors in model_files_safetensors:
810
+ _state_dict = load_file(_model_file_safetensors)
811
+ for key in _state_dict:
812
+ state_dict[key] = _state_dict[key]
813
+
814
+ if diffusers_version >= "0.33.0":
815
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
816
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
817
+ load_model_dict_into_meta(
818
+ model,
819
+ state_dict,
820
+ dtype=torch_dtype,
821
+ model_name_or_path=pretrained_model_path,
822
+ )
823
+ else:
824
+ model._convert_deprecated_attention_blocks(state_dict)
825
+ # move the params from meta device to cpu
826
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
827
+ if len(missing_keys) > 0:
828
+ raise ValueError(
829
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
830
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
831
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
832
+ " those weights or else make sure your checkpoint file is correct."
833
+ )
834
+
835
+ unexpected_keys = load_model_dict_into_meta(
836
+ model,
837
+ state_dict,
838
+ device=param_device,
839
+ dtype=torch_dtype,
840
+ model_name_or_path=pretrained_model_path,
841
+ )
842
+
843
+ if cls._keys_to_ignore_on_load_unexpected is not None:
844
+ for pat in cls._keys_to_ignore_on_load_unexpected:
845
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
846
+
847
+ if len(unexpected_keys) > 0:
848
+ print(
849
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
850
+ )
851
+
852
+ return model
853
+ except Exception as e:
854
+ print(
855
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
856
+ )
857
+
858
+ model = cls.from_config(config, **transformer_additional_kwargs)
859
+ if os.path.exists(model_file):
860
+ state_dict = torch.load(model_file, map_location="cpu")
861
+ elif os.path.exists(model_file_safetensors):
862
+ from safetensors.torch import load_file, safe_open
863
+ state_dict = load_file(model_file_safetensors)
864
+ else:
865
+ from safetensors.torch import load_file, safe_open
866
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
867
+ state_dict = {}
868
+ for _model_file_safetensors in model_files_safetensors:
869
+ _state_dict = load_file(_model_file_safetensors)
870
+ for key in _state_dict:
871
+ state_dict[key] = _state_dict[key]
872
+
873
+ tmp_state_dict = {}
874
+ for key in state_dict:
875
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
876
+ tmp_state_dict[key] = state_dict[key]
877
+ else:
878
+ print(key, "Size don't match, skip")
879
+
880
+ state_dict = tmp_state_dict
881
+
882
+ m, u = model.load_state_dict(state_dict, strict=False)
883
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
884
+ print(m)
885
+
886
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
887
+ print(f"### All Parameters: {sum(params) / 1e6} M")
888
+
889
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
890
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
891
+
892
+ model = model.to(torch_dtype)
893
+ return model
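+
+ # Illustrative usage sketch (not from the original commit): loading the transformer from
+ # a local checkpoint directory. The path, subfolder and variable names below are
+ # hypothetical placeholders.
+ #
+ #   model = QwenImageTransformer2DModel.from_pretrained(
+ #       "path/to/QwenImage", subfolder="transformer",
+ #       low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
+ #   )
+ #   out = model(
+ #       hidden_states, encoder_hidden_states, encoder_hidden_states_mask, timestep,
+ #       img_shapes=[(1, grid_height, grid_width)], txt_seq_lens=[text_len],
+ #   )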
videox_fun/models/qwenimage_vae.py ADDED
@@ -0,0 +1,1087 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
2
+ # Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # We gratefully acknowledge the Wan Team for their outstanding contributions.
17
+ # QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
18
+ # For more information about the Wan VAE, please refer to:
19
+ # - GitHub: https://github.com/Wan-Video/Wan2.1
20
+ # - arXiv: https://arxiv.org/abs/2503.20314
21
+
22
+ import functools
23
+ import glob
24
+ import json
25
+ import math
26
+ import os
27
+ import types
28
+ import warnings
29
+ from typing import Any, Dict, List, Optional, Tuple, Union
30
+
31
+ import numpy as np
32
+ import torch
33
+ import torch.cuda.amp as amp
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+ import torch.utils.checkpoint
37
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
38
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
39
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
40
+ from diffusers.models.activations import get_activation
41
+ from diffusers.models.attention import FeedForward
42
+ from diffusers.models.attention_processor import Attention
43
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
44
+ DiagonalGaussianDistribution)
45
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
46
+ from diffusers.models.modeling_outputs import (AutoencoderKLOutput,
47
+ Transformer2DModelOutput)
48
+ from diffusers.models.modeling_utils import ModelMixin
49
+ from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
50
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
51
+ scale_lora_layers, unscale_lora_layers)
52
+ from diffusers.utils.accelerate_utils import apply_forward_hook
53
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
54
+ from torch import nn
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+ CACHE_T = 2
59
+
60
+ class QwenImageCausalConv3d(nn.Conv3d):
61
+ r"""
62
+ A custom 3D causal convolution layer with feature caching support.
63
+
64
+ This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
65
+ caching for efficient inference.
66
+
67
+ Args:
68
+ in_channels (int): Number of channels in the input image
69
+ out_channels (int): Number of channels produced by the convolution
70
+ kernel_size (int or tuple): Size of the convolving kernel
71
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
72
+ padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ in_channels: int,
78
+ out_channels: int,
79
+ kernel_size: Union[int, Tuple[int, int, int]],
80
+ stride: Union[int, Tuple[int, int, int]] = 1,
81
+ padding: Union[int, Tuple[int, int, int]] = 0,
82
+ ) -> None:
83
+ super().__init__(
84
+ in_channels=in_channels,
85
+ out_channels=out_channels,
86
+ kernel_size=kernel_size,
87
+ stride=stride,
88
+ padding=padding,
89
+ )
90
+
91
+ # Set up causal padding
92
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
93
+ self.padding = (0, 0, 0)
94
+
95
+ def forward(self, x, cache_x=None):
96
+ padding = list(self._padding)
97
+ if cache_x is not None and self._padding[4] > 0:
98
+ cache_x = cache_x.to(x.device)
99
+ x = torch.cat([cache_x, x], dim=2)
100
+ padding[4] -= cache_x.shape[2]
101
+ x = F.pad(x, padding)
102
+ return super().forward(x)
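+
+ # Illustrative sketch (not from the original commit): with kernel_size=3 and padding=1,
+ # self._padding becomes (1, 1, 1, 1, 2, 0) -- symmetric spatial padding but two frames of
+ # left-only temporal padding, so each output frame only depends on current and past
+ # frames. Shapes below are hypothetical.
+ #
+ #   conv = QwenImageCausalConv3d(3, 8, kernel_size=3, padding=1)
+ #   x = torch.randn(1, 3, 5, 32, 32)     # [B, C, T, H, W]
+ #   conv(x).shape                        # torch.Size([1, 8, 5, 32, 32])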
103
+
104
+
105
+ class QwenImageRMS_norm(nn.Module):
106
+ r"""
107
+ A custom RMS normalization layer.
108
+
109
+ Args:
110
+ dim (int): The number of channels to normalize over.
111
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
112
+ Default is True.
113
+ images (bool, optional): Whether the input represents image data. Default is True.
114
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
115
+ """
116
+
117
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
118
+ super().__init__()
119
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
120
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
121
+
122
+ self.channel_first = channel_first
123
+ self.scale = dim**0.5
124
+ self.gamma = nn.Parameter(torch.ones(shape))
125
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
126
+
127
+ def forward(self, x):
128
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
129
+
130
+
131
+ class QwenImageUpsample(nn.Upsample):
132
+ r"""
133
+ Perform upsampling while ensuring the output tensor has the same data type as the input.
134
+
135
+ Args:
136
+ x (torch.Tensor): Input tensor to be upsampled.
137
+
138
+ Returns:
139
+ torch.Tensor: Upsampled tensor with the same data type as the input.
140
+ """
141
+
142
+ def forward(self, x):
143
+ return super().forward(x.float()).type_as(x)
144
+
145
+
146
+ class QwenImageResample(nn.Module):
147
+ r"""
148
+ A custom resampling module for 2D and 3D data.
149
+
150
+ Args:
151
+ dim (int): The number of input/output channels.
152
+ mode (str): The resampling mode. Must be one of:
153
+ - 'none': No resampling (identity operation).
154
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
155
+ - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
156
+ - 'downsample2d': 2D downsampling with zero-padding and convolution.
157
+ - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
158
+ """
159
+
160
+ def __init__(self, dim: int, mode: str) -> None:
161
+ super().__init__()
162
+ self.dim = dim
163
+ self.mode = mode
164
+
165
+ # layers
166
+ if mode == "upsample2d":
167
+ self.resample = nn.Sequential(
168
+ QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
169
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
170
+ )
171
+ elif mode == "upsample3d":
172
+ self.resample = nn.Sequential(
173
+ QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
174
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
175
+ )
176
+ self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
177
+
178
+ elif mode == "downsample2d":
179
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
180
+ elif mode == "downsample3d":
181
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
182
+ self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
183
+
184
+ else:
185
+ self.resample = nn.Identity()
186
+
187
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
188
+ b, c, t, h, w = x.size()
189
+ if self.mode == "upsample3d":
190
+ if feat_cache is not None:
191
+ idx = feat_idx[0]
192
+ if feat_cache[idx] is None:
193
+ feat_cache[idx] = "Rep"
194
+ feat_idx[0] += 1
195
+ else:
196
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
197
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
198
+ # also cache the last frame of the previous chunk
199
+ cache_x = torch.cat(
200
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
201
+ )
202
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
203
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
204
+ if feat_cache[idx] == "Rep":
205
+ x = self.time_conv(x)
206
+ else:
207
+ x = self.time_conv(x, feat_cache[idx])
208
+ feat_cache[idx] = cache_x
209
+ feat_idx[0] += 1
210
+
211
+ x = x.reshape(b, 2, c, t, h, w)
212
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
213
+ x = x.reshape(b, c, t * 2, h, w)
214
+ t = x.shape[2]
215
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
216
+ x = self.resample(x)
217
+ x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
218
+
219
+ if self.mode == "downsample3d":
220
+ if feat_cache is not None:
221
+ idx = feat_idx[0]
222
+ if feat_cache[idx] is None:
223
+ feat_cache[idx] = x.clone()
224
+ feat_idx[0] += 1
225
+ else:
226
+ cache_x = x[:, :, -1:, :, :].clone()
227
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
228
+ feat_cache[idx] = cache_x
229
+ feat_idx[0] += 1
230
+ return x
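+
+ # Illustrative note (not from the original commit): in "upsample3d" mode, once the
+ # feature cache holds a previous chunk, time_conv doubles the channel count and the
+ # reshape/stack above interleaves the two halves along time, so a [B, C, T, H, W] input
+ # leaves the temporal branch as [B, C, 2*T, H, W] before the frame-by-frame 2x spatial
+ # upsampling in self.resample.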
231
+
232
+
233
+ class QwenImageResidualBlock(nn.Module):
234
+ r"""
235
+ A custom residual block module.
236
+
237
+ Args:
238
+ in_dim (int): Number of input channels.
239
+ out_dim (int): Number of output channels.
240
+ dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
241
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ in_dim: int,
247
+ out_dim: int,
248
+ dropout: float = 0.0,
249
+ non_linearity: str = "silu",
250
+ ) -> None:
251
+ super().__init__()
252
+ self.in_dim = in_dim
253
+ self.out_dim = out_dim
254
+ self.nonlinearity = get_activation(non_linearity)
255
+
256
+ # layers
257
+ self.norm1 = QwenImageRMS_norm(in_dim, images=False)
258
+ self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
259
+ self.norm2 = QwenImageRMS_norm(out_dim, images=False)
260
+ self.dropout = nn.Dropout(dropout)
261
+ self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
262
+ self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
263
+
264
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
265
+ # Apply shortcut connection
266
+ h = self.conv_shortcut(x)
267
+
268
+ # First normalization and activation
269
+ x = self.norm1(x)
270
+ x = self.nonlinearity(x)
271
+
272
+ if feat_cache is not None:
273
+ idx = feat_idx[0]
274
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
275
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
276
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
277
+
278
+ x = self.conv1(x, feat_cache[idx])
279
+ feat_cache[idx] = cache_x
280
+ feat_idx[0] += 1
281
+ else:
282
+ x = self.conv1(x)
283
+
284
+ # Second normalization and activation
285
+ x = self.norm2(x)
286
+ x = self.nonlinearity(x)
287
+
288
+ # Dropout
289
+ x = self.dropout(x)
290
+
291
+ if feat_cache is not None:
292
+ idx = feat_idx[0]
293
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
294
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
295
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
296
+
297
+ x = self.conv2(x, feat_cache[idx])
298
+ feat_cache[idx] = cache_x
299
+ feat_idx[0] += 1
300
+ else:
301
+ x = self.conv2(x)
302
+
303
+ # Add residual connection
304
+ return x + h
305
+
306
+
307
+ class QwenImageAttentionBlock(nn.Module):
308
+ r"""
309
+ Causal self-attention with a single head.
310
+
311
+ Args:
312
+ dim (int): The number of channels in the input tensor.
313
+ """
314
+
315
+ def __init__(self, dim):
316
+ super().__init__()
317
+ self.dim = dim
318
+
319
+ # layers
320
+ self.norm = QwenImageRMS_norm(dim)
321
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
322
+ self.proj = nn.Conv2d(dim, dim, 1)
323
+
324
+ def forward(self, x):
325
+ identity = x
326
+ batch_size, channels, time, height, width = x.size()
327
+
328
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
329
+ x = self.norm(x)
330
+
331
+ # compute query, key, value
332
+ qkv = self.to_qkv(x)
333
+ qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
334
+ qkv = qkv.permute(0, 1, 3, 2).contiguous()
335
+ q, k, v = qkv.chunk(3, dim=-1)
336
+
337
+ # apply attention
338
+ x = F.scaled_dot_product_attention(q, k, v)
339
+
340
+ x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
341
+
342
+ # output projection
343
+ x = self.proj(x)
344
+
345
+ # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
346
+ x = x.view(batch_size, time, channels, height, width)
347
+ x = x.permute(0, 2, 1, 3, 4)
348
+
349
+ return x + identity
350
+
351
+
352
+ class QwenImageMidBlock(nn.Module):
353
+ """
354
+ Middle block for QwenImageVAE encoder and decoder.
355
+
356
+ Args:
357
+ dim (int): Number of input/output channels.
358
+ dropout (float): Dropout rate.
359
+ non_linearity (str): Type of non-linearity to use.
360
+ """
361
+
362
+ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
363
+ super().__init__()
364
+ self.dim = dim
365
+
366
+ # Create the components
367
+ resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
368
+ attentions = []
369
+ for _ in range(num_layers):
370
+ attentions.append(QwenImageAttentionBlock(dim))
371
+ resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
372
+ self.attentions = nn.ModuleList(attentions)
373
+ self.resnets = nn.ModuleList(resnets)
374
+
375
+ self.gradient_checkpointing = False
376
+
377
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
378
+ # First residual block
379
+ x = self.resnets[0](x, feat_cache, feat_idx)
380
+
381
+ # Process through attention and residual blocks
382
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
383
+ if attn is not None:
384
+ x = attn(x)
385
+
386
+ x = resnet(x, feat_cache, feat_idx)
387
+
388
+ return x
389
+
390
+
391
+ class QwenImageEncoder3d(nn.Module):
392
+ r"""
393
+ A 3D encoder module.
394
+
395
+ Args:
396
+ dim (int): The base number of channels in the first layer.
397
+ z_dim (int): The dimensionality of the latent space.
398
+ dim_mult (list of int): Multipliers for the number of channels in each block.
399
+ num_res_blocks (int): Number of residual blocks in each block.
400
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
401
+ temperal_downsample (list of bool): Whether to downsample temporally in each block.
402
+ dropout (float): Dropout rate for the dropout layers.
403
+ non_linearity (str): Type of non-linearity to use.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ dim=128,
409
+ z_dim=4,
410
+ dim_mult=[1, 2, 4, 4],
411
+ num_res_blocks=2,
412
+ attn_scales=[],
413
+ temperal_downsample=[True, True, False],
414
+ dropout=0.0,
415
+ non_linearity: str = "silu",
416
+ ):
417
+ super().__init__()
418
+ self.dim = dim
419
+ self.z_dim = z_dim
420
+ self.dim_mult = dim_mult
421
+ self.num_res_blocks = num_res_blocks
422
+ self.attn_scales = attn_scales
423
+ self.temperal_downsample = temperal_downsample
424
+ self.nonlinearity = get_activation(non_linearity)
425
+
426
+ # dimensions
427
+ dims = [dim * u for u in [1] + dim_mult]
428
+ scale = 1.0
429
+
430
+ # init block
431
+ self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
432
+
433
+ # downsample blocks
434
+ self.down_blocks = nn.ModuleList([])
435
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
436
+ # residual (+attention) blocks
437
+ for _ in range(num_res_blocks):
438
+ self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
439
+ if scale in attn_scales:
440
+ self.down_blocks.append(QwenImageAttentionBlock(out_dim))
441
+ in_dim = out_dim
442
+
443
+ # downsample block
444
+ if i != len(dim_mult) - 1:
445
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
446
+ self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
447
+ scale /= 2.0
448
+
449
+ # middle blocks
450
+ self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
451
+
452
+ # output blocks
453
+ self.norm_out = QwenImageRMS_norm(out_dim, images=False)
454
+ self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
455
+
456
+ self.gradient_checkpointing = False
457
+
458
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
459
+ if feat_cache is not None:
460
+ idx = feat_idx[0]
461
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
462
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
463
+ # also cache the last frame of the previous chunk
464
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
465
+ x = self.conv_in(x, feat_cache[idx])
466
+ feat_cache[idx] = cache_x
467
+ feat_idx[0] += 1
468
+ else:
469
+ x = self.conv_in(x)
470
+
471
+ ## downsamples
472
+ for layer in self.down_blocks:
473
+ if feat_cache is not None:
474
+ x = layer(x, feat_cache, feat_idx)
475
+ else:
476
+ x = layer(x)
477
+
478
+ ## middle
479
+ x = self.mid_block(x, feat_cache, feat_idx)
480
+
481
+ ## head
482
+ x = self.norm_out(x)
483
+ x = self.nonlinearity(x)
484
+ if feat_cache is not None:
485
+ idx = feat_idx[0]
486
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
487
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
488
+ # also cache the last frame of the previous chunk
489
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
490
+ x = self.conv_out(x, feat_cache[idx])
491
+ feat_cache[idx] = cache_x
492
+ feat_idx[0] += 1
493
+ else:
494
+ x = self.conv_out(x)
495
+ return x
496
+
497
+
498
+ class QwenImageUpBlock(nn.Module):
499
+ """
500
+ A block that handles upsampling for the QwenImageVAE decoder.
501
+
502
+ Args:
503
+ in_dim (int): Input dimension
504
+ out_dim (int): Output dimension
505
+ num_res_blocks (int): Number of residual blocks
506
+ dropout (float): Dropout rate
507
+ upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
508
+ non_linearity (str): Type of non-linearity to use
509
+ """
510
+
511
+ def __init__(
512
+ self,
513
+ in_dim: int,
514
+ out_dim: int,
515
+ num_res_blocks: int,
516
+ dropout: float = 0.0,
517
+ upsample_mode: Optional[str] = None,
518
+ non_linearity: str = "silu",
519
+ ):
520
+ super().__init__()
521
+ self.in_dim = in_dim
522
+ self.out_dim = out_dim
523
+
524
+ # Create layers list
525
+ resnets = []
526
+ # Add residual blocks and attention if needed
527
+ current_dim = in_dim
528
+ for _ in range(num_res_blocks + 1):
529
+ resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
530
+ current_dim = out_dim
531
+
532
+ self.resnets = nn.ModuleList(resnets)
533
+
534
+ # Add upsampling layer if needed
535
+ self.upsamplers = None
536
+ if upsample_mode is not None:
537
+ self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
538
+
539
+ self.gradient_checkpointing = False
540
+
541
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
542
+ """
543
+ Forward pass through the upsampling block.
544
+
545
+ Args:
546
+ x (torch.Tensor): Input tensor
547
+ feat_cache (list, optional): Feature cache for causal convolutions
548
+ feat_idx (list, optional): Feature index for cache management
549
+
550
+ Returns:
551
+ torch.Tensor: Output tensor
552
+ """
553
+ for resnet in self.resnets:
554
+ if feat_cache is not None:
555
+ x = resnet(x, feat_cache, feat_idx)
556
+ else:
557
+ x = resnet(x)
558
+
559
+ if self.upsamplers is not None:
560
+ if feat_cache is not None:
561
+ x = self.upsamplers[0](x, feat_cache, feat_idx)
562
+ else:
563
+ x = self.upsamplers[0](x)
564
+ return x
565
+
566
+
567
+ class QwenImageDecoder3d(nn.Module):
568
+ r"""
569
+ A 3D decoder module.
570
+
571
+ Args:
572
+ dim (int): The base number of channels in the first layer.
573
+ z_dim (int): The dimensionality of the latent space.
574
+ dim_mult (list of int): Multipliers for the number of channels in each block.
575
+ num_res_blocks (int): Number of residual blocks in each block.
576
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
577
+ temperal_upsample (list of bool): Whether to upsample temporally in each block.
578
+ dropout (float): Dropout rate for the dropout layers.
579
+ non_linearity (str): Type of non-linearity to use.
580
+ """
581
+
582
+ def __init__(
583
+ self,
584
+ dim=128,
585
+ z_dim=4,
586
+ dim_mult=[1, 2, 4, 4],
587
+ num_res_blocks=2,
588
+ attn_scales=[],
589
+ temperal_upsample=[False, True, True],
590
+ dropout=0.0,
591
+ non_linearity: str = "silu",
592
+ ):
593
+ super().__init__()
594
+ self.dim = dim
595
+ self.z_dim = z_dim
596
+ self.dim_mult = dim_mult
597
+ self.num_res_blocks = num_res_blocks
598
+ self.attn_scales = attn_scales
599
+ self.temperal_upsample = temperal_upsample
600
+
601
+ self.nonlinearity = get_activation(non_linearity)
602
+
603
+ # dimensions
604
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
605
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
606
+
607
+ # init block
608
+ self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
609
+
610
+ # middle blocks
611
+ self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
612
+
613
+ # upsample blocks
614
+ self.up_blocks = nn.ModuleList([])
615
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
616
+ # residual (+attention) blocks
617
+ if i > 0:
618
+ in_dim = in_dim // 2
619
+
620
+ # Determine if we need upsampling
621
+ upsample_mode = None
622
+ if i != len(dim_mult) - 1:
623
+ upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
624
+
625
+ # Create and add the upsampling block
626
+ up_block = QwenImageUpBlock(
627
+ in_dim=in_dim,
628
+ out_dim=out_dim,
629
+ num_res_blocks=num_res_blocks,
630
+ dropout=dropout,
631
+ upsample_mode=upsample_mode,
632
+ non_linearity=non_linearity,
633
+ )
634
+ self.up_blocks.append(up_block)
635
+
636
+ # Update scale for next iteration
637
+ if upsample_mode is not None:
638
+ scale *= 2.0
639
+
640
+ # output blocks
641
+ self.norm_out = QwenImageRMS_norm(out_dim, images=False)
642
+ self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
643
+
644
+ self.gradient_checkpointing = False
645
+
646
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
647
+ ## conv1
648
+ if feat_cache is not None:
649
+ idx = feat_idx[0]
650
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
651
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
652
+ # also cache the last frame of the previous chunk
653
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
654
+ x = self.conv_in(x, feat_cache[idx])
655
+ feat_cache[idx] = cache_x
656
+ feat_idx[0] += 1
657
+ else:
658
+ x = self.conv_in(x)
659
+
660
+ ## middle
661
+ x = self.mid_block(x, feat_cache, feat_idx)
662
+
663
+ ## upsamples
664
+ for up_block in self.up_blocks:
665
+ x = up_block(x, feat_cache, feat_idx)
666
+
667
+ ## head
668
+ x = self.norm_out(x)
669
+ x = self.nonlinearity(x)
670
+ if feat_cache is not None:
671
+ idx = feat_idx[0]
672
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
673
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
674
+ # also cache the last frame of the previous chunk
675
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
676
+ x = self.conv_out(x, feat_cache[idx])
677
+ feat_cache[idx] = cache_x
678
+ feat_idx[0] += 1
679
+ else:
680
+ x = self.conv_out(x)
681
+ return x
682
+
683
+
684
+ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
685
+ r"""
686
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
687
+
688
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
689
+ for all models (such as downloading or saving).
690
+ """
691
+
692
+ _supports_gradient_checkpointing = False
693
+
694
+ # fmt: off
695
+ @register_to_config
696
+ def __init__(
697
+ self,
698
+ base_dim: int = 96,
699
+ z_dim: int = 16,
700
+ dim_mult: Tuple[int] = [1, 2, 4, 4],
701
+ num_res_blocks: int = 2,
702
+ attn_scales: List[float] = [],
703
+ temperal_downsample: List[bool] = [False, True, True],
704
+ dropout: float = 0.0,
705
+ latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
706
+ latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
707
+ ) -> None:
708
+ # fmt: on
709
+ super().__init__()
710
+
711
+ self.z_dim = z_dim
712
+ self.temperal_downsample = temperal_downsample
713
+ self.temperal_upsample = temperal_downsample[::-1]
714
+
715
+ self.encoder = QwenImageEncoder3d(
716
+ base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
717
+ )
718
+ self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
719
+ self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
720
+
721
+ self.decoder = QwenImageDecoder3d(
722
+ base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
723
+ )
724
+
725
+ self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
726
+
727
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
728
+ # to perform decoding of a single video latent at a time.
729
+ self.use_slicing = False
730
+
731
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
732
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
733
+ # intermediate tiles together, the memory requirement can be lowered.
734
+ self.use_tiling = False
735
+
736
+ # The minimal tile height and width for spatial tiling to be used
737
+ self.tile_sample_min_height = 256
738
+ self.tile_sample_min_width = 256
739
+
740
+ # The minimal distance between two spatial tiles
741
+ self.tile_sample_stride_height = 192
742
+ self.tile_sample_stride_width = 192
743
+
744
+ # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
745
+ self._cached_conv_counts = {
746
+ "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())
747
+ if self.decoder is not None
748
+ else 0,
749
+ "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())
750
+ if self.encoder is not None
751
+ else 0,
752
+ }
753
+
754
+ def enable_tiling(
755
+ self,
756
+ tile_sample_min_height: Optional[int] = None,
757
+ tile_sample_min_width: Optional[int] = None,
758
+ tile_sample_stride_height: Optional[float] = None,
759
+ tile_sample_stride_width: Optional[float] = None,
760
+ ) -> None:
761
+ r"""
762
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
763
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
764
+ processing larger images.
765
+
766
+ Args:
767
+ tile_sample_min_height (`int`, *optional*):
768
+ The minimum height required for a sample to be separated into tiles across the height dimension.
769
+ tile_sample_min_width (`int`, *optional*):
770
+ The minimum width required for a sample to be separated into tiles across the width dimension.
771
+ tile_sample_stride_height (`int`, *optional*):
772
+ The stride between two consecutive vertical tiles. This is to ensure that there are
773
+ no tiling artifacts produced across the height dimension.
774
+ tile_sample_stride_width (`int`, *optional*):
775
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
776
+ artifacts produced across the width dimension.
777
+ """
778
+ self.use_tiling = True
779
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
780
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
781
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
782
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
783
+
784
+ def disable_tiling(self) -> None:
785
+ r"""
786
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
787
+ decoding in one step.
788
+ """
789
+ self.use_tiling = False
790
+
791
+ def enable_slicing(self) -> None:
792
+ r"""
793
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
794
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
795
+ """
796
+ self.use_slicing = True
797
+
798
+ def disable_slicing(self) -> None:
799
+ r"""
800
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
801
+ decoding in one step.
802
+ """
803
+ self.use_slicing = False
804
+
805
+ def clear_cache(self):
806
+ def _count_conv3d(model):
807
+ count = 0
808
+ for m in model.modules():
809
+ if isinstance(m, QwenImageCausalConv3d):
810
+ count += 1
811
+ return count
812
+
813
+ self._conv_num = _count_conv3d(self.decoder)
814
+ self._conv_idx = [0]
815
+ self._feat_map = [None] * self._conv_num
816
+ # cache encode
817
+ self._enc_conv_num = _count_conv3d(self.encoder)
818
+ self._enc_conv_idx = [0]
819
+ self._enc_feat_map = [None] * self._enc_conv_num
820
+
821
+ def _encode(self, x: torch.Tensor):
822
+ _, _, num_frame, height, width = x.shape
823
+
824
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
825
+ return self.tiled_encode(x)
826
+
827
+ self.clear_cache()
828
+ iter_ = 1 + (num_frame - 1) // 4
829
+ for i in range(iter_):
830
+ self._enc_conv_idx = [0]
831
+ if i == 0:
832
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
833
+ else:
834
+ out_ = self.encoder(
835
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
836
+ feat_cache=self._enc_feat_map,
837
+ feat_idx=self._enc_conv_idx,
838
+ )
839
+ out = torch.cat([out, out_], 2)
840
+
841
+ enc = self.quant_conv(out)
842
+ self.clear_cache()
843
+ return enc
844
+
845
+ @apply_forward_hook
846
+ def encode(
847
+ self, x: torch.Tensor, return_dict: bool = True
848
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
849
+ r"""
850
+ Encode a batch of images into latents.
851
+
852
+ Args:
853
+ x (`torch.Tensor`): Input batch of images.
854
+ return_dict (`bool`, *optional*, defaults to `True`):
855
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
856
+
857
+ Returns:
858
+ The latent representations of the encoded videos. If `return_dict` is True, a
859
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
860
+ """
861
+ if self.use_slicing and x.shape[0] > 1:
862
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
863
+ h = torch.cat(encoded_slices)
864
+ else:
865
+ h = self._encode(x)
866
+ posterior = DiagonalGaussianDistribution(h)
867
+
868
+ if not return_dict:
869
+ return (posterior,)
870
+ return AutoencoderKLOutput(latent_dist=posterior)
871
+
872
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
873
+ _, _, num_frame, height, width = z.shape
874
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
875
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
876
+
877
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
878
+ return self.tiled_decode(z, return_dict=return_dict)
879
+
880
+ self.clear_cache()
881
+ x = self.post_quant_conv(z)
882
+ for i in range(num_frame):
883
+ self._conv_idx = [0]
884
+ if i == 0:
885
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
886
+ else:
887
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
888
+ out = torch.cat([out, out_], 2)
889
+
890
+ out = torch.clamp(out, min=-1.0, max=1.0)
891
+ self.clear_cache()
892
+ if not return_dict:
893
+ return (out,)
894
+
895
+ return DecoderOutput(sample=out)
896
+
897
+ @apply_forward_hook
898
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
899
+ r"""
900
+ Decode a batch of images.
901
+
902
+ Args:
903
+ z (`torch.Tensor`): Input batch of latent vectors.
904
+ return_dict (`bool`, *optional*, defaults to `True`):
905
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
906
+
907
+ Returns:
908
+ [`~models.vae.DecoderOutput`] or `tuple`:
909
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
910
+ returned.
911
+ """
912
+ if self.use_slicing and z.shape[0] > 1:
913
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
914
+ decoded = torch.cat(decoded_slices)
915
+ else:
916
+ decoded = self._decode(z).sample
917
+
918
+ if not return_dict:
919
+ return (decoded,)
920
+ return DecoderOutput(sample=decoded)
921
+
922
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
923
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
924
+ for y in range(blend_extent):
925
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
926
+ y / blend_extent
927
+ )
928
+ return b
929
+
930
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
931
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
932
+ for x in range(blend_extent):
933
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
934
+ x / blend_extent
935
+ )
936
+ return b
937
+
938
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
939
+ r"""Encode a batch of images using a tiled encoder.
940
+
941
+ Args:
942
+ x (`torch.Tensor`): Input batch of videos.
943
+
944
+ Returns:
945
+ `torch.Tensor`:
946
+ The latent representation of the encoded videos.
947
+ """
948
+ _, _, num_frames, height, width = x.shape
949
+ latent_height = height // self.spatial_compression_ratio
950
+ latent_width = width // self.spatial_compression_ratio
951
+
952
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
953
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
954
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
955
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
956
+
957
+ blend_height = tile_latent_min_height - tile_latent_stride_height
958
+ blend_width = tile_latent_min_width - tile_latent_stride_width
959
+
960
+ # Split x into overlapping tiles and encode them separately.
961
+ # The tiles have an overlap to avoid seams between tiles.
962
+ rows = []
963
+ for i in range(0, height, self.tile_sample_stride_height):
964
+ row = []
965
+ for j in range(0, width, self.tile_sample_stride_width):
966
+ self.clear_cache()
967
+ time = []
968
+ frame_range = 1 + (num_frames - 1) // 4
969
+ for k in range(frame_range):
970
+ self._enc_conv_idx = [0]
971
+ if k == 0:
972
+ tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
973
+ else:
974
+ tile = x[
975
+ :,
976
+ :,
977
+ 1 + 4 * (k - 1) : 1 + 4 * k,
978
+ i : i + self.tile_sample_min_height,
979
+ j : j + self.tile_sample_min_width,
980
+ ]
981
+ tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
982
+ tile = self.quant_conv(tile)
983
+ time.append(tile)
984
+ row.append(torch.cat(time, dim=2))
985
+ rows.append(row)
986
+ self.clear_cache()
987
+
988
+ result_rows = []
989
+ for i, row in enumerate(rows):
990
+ result_row = []
991
+ for j, tile in enumerate(row):
992
+ # blend the above tile and the left tile
993
+ # to the current tile and add the current tile to the result row
994
+ if i > 0:
995
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
996
+ if j > 0:
997
+ tile = self.blend_h(row[j - 1], tile, blend_width)
998
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
999
+ result_rows.append(torch.cat(result_row, dim=-1))
1000
+
1001
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
1002
+ return enc
1003
+
1004
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1005
+ r"""
1006
+ Decode a batch of images using a tiled decoder.
1007
+
1008
+ Args:
1009
+ z (`torch.Tensor`): Input batch of latent vectors.
1010
+ return_dict (`bool`, *optional*, defaults to `True`):
1011
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1012
+
1013
+ Returns:
1014
+ [`~models.vae.DecoderOutput`] or `tuple`:
1015
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1016
+ returned.
1017
+ """
1018
+ _, _, num_frames, height, width = z.shape
1019
+ sample_height = height * self.spatial_compression_ratio
1020
+ sample_width = width * self.spatial_compression_ratio
1021
+
1022
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1023
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1024
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1025
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1026
+
1027
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
1028
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
1029
+
1030
+ # Split z into overlapping tiles and decode them separately.
1031
+ # The tiles have an overlap to avoid seams between tiles.
1032
+ rows = []
1033
+ for i in range(0, height, tile_latent_stride_height):
1034
+ row = []
1035
+ for j in range(0, width, tile_latent_stride_width):
1036
+ self.clear_cache()
1037
+ time = []
1038
+ for k in range(num_frames):
1039
+ self._conv_idx = [0]
1040
+ tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
1041
+ tile = self.post_quant_conv(tile)
1042
+ decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
1043
+ time.append(decoded)
1044
+ row.append(torch.cat(time, dim=2))
1045
+ rows.append(row)
1046
+ self.clear_cache()
1047
+
1048
+ result_rows = []
1049
+ for i, row in enumerate(rows):
1050
+ result_row = []
1051
+ for j, tile in enumerate(row):
1052
+ # blend the above tile and the left tile
1053
+ # to the current tile and add the current tile to the result row
1054
+ if i > 0:
1055
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1056
+ if j > 0:
1057
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1058
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
1059
+ result_rows.append(torch.cat(result_row, dim=-1))
1060
+
1061
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
1062
+
1063
+ if not return_dict:
1064
+ return (dec,)
1065
+ return DecoderOutput(sample=dec)
1066
+
1067
+ def forward(
1068
+ self,
1069
+ sample: torch.Tensor,
1070
+ sample_posterior: bool = False,
1071
+ return_dict: bool = True,
1072
+ generator: Optional[torch.Generator] = None,
1073
+ ) -> Union[DecoderOutput, torch.Tensor]:
1074
+ """
1075
+ Args:
1076
+ sample (`torch.Tensor`): Input sample.
1077
+ return_dict (`bool`, *optional*, defaults to `True`):
1078
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1079
+ """
1080
+ x = sample
1081
+ posterior = self.encode(x).latent_dist
1082
+ if sample_posterior:
1083
+ z = posterior.sample(generator=generator)
1084
+ else:
1085
+ z = posterior.mode()
1086
+ dec = self.decode(z, return_dict=return_dict)
1087
+ return dec
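As an illustrative aside, the slicing and tiling switches defined above are meant to be flipped before calling decode. The sketch below shows one way this class can be exercised; the from-scratch construction with the default config and the latent shape are assumptions for illustration, not something this file prescribes.
import torch
from videox_fun.models import AutoencoderKLQwenImage

vae = AutoencoderKLQwenImage()      # default config: z_dim=16, 8x spatial compression
vae.enable_tiling()                 # tile spatially once latents exceed 256 / 8 = 32 pixels per side
vae.enable_slicing()                # decode one latent per forward pass across the batch

z = torch.randn(2, 16, 1, 64, 64)   # hypothetical latent batch [B, C, T, H, W]
with torch.no_grad():
    video = vae.decode(z).sample    # -> [2, 3, 1, 512, 512]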
videox_fun/models/wan_camera_adapter.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class SimpleAdapter(nn.Module):
6
+ def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1):
7
+ super(SimpleAdapter, self).__init__()
8
+
9
+ # Pixel Unshuffle: reduce spatial dimensions by `downscale_factor` (8 by default)
10
+ self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)
11
+
12
+ # Convolution: further reduce spatial dimensions according to
13
+ # kernel_size / stride (e.g., by a factor of 2 with no overlap when kernel_size == stride == 2)
14
+ self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
15
+
16
+ # Residual blocks for feature extraction
17
+ self.residual_blocks = nn.Sequential(
18
+ *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
19
+ )
20
+
21
+ def forward(self, x):
22
+ # Reshape to merge the frame dimension into batch
23
+ bs, c, f, h, w = x.size()
24
+ x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
25
+
26
+ # Pixel Unshuffle operation
27
+ x_unshuffled = self.pixel_unshuffle(x)
28
+
29
+ # Convolution operation
30
+ x_conv = self.conv(x_unshuffled)
31
+
32
+ # Feature extraction with residual blocks
33
+ out = self.residual_blocks(x_conv)
34
+
35
+ # Reshape to split the merged (batch * frames) dimension back into separate batch and frame dimensions
36
+ out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
37
+
38
+ # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
39
+ out = out.permute(0, 2, 1, 3, 4)
40
+
41
+ return out
42
+
43
+
44
+ class ResidualBlock(nn.Module):
45
+ def __init__(self, dim):
46
+ super(ResidualBlock, self).__init__()
47
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
48
+ self.relu = nn.ReLU(inplace=True)
49
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
50
+
51
+ def forward(self, x):
52
+ residual = x
53
+ out = self.relu(self.conv1(x))
54
+ out = self.conv2(out)
55
+ out += residual
56
+ return out
57
+
58
+ # Example usage
59
+ # in_dim = 3
60
+ # out_dim = 64
61
+ # adapter = SimpleAdapter(in_dim, out_dim, kernel_size=2, stride=2)  # kernel_size/stride values here are illustrative
62
+ # x = torch.randn(1, in_dim, 4, 64, 64) # e.g., batch size = 1, channels = 3, frames/features = 4
63
+ # output = adapter(x)
64
+ # print(output.shape) # Should reflect transformed dimensions
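A runnable version of the commented sketch above, tracing the shape flow through pixel unshuffle, the strided convolution and the residual blocks; the kernel_size/stride values are illustrative assumptions, since the adapter leaves them to the caller.
import torch
from videox_fun.models.wan_camera_adapter import SimpleAdapter

adapter = SimpleAdapter(in_dim=6, out_dim=64, kernel_size=2, stride=2, downscale_factor=8)
x = torch.randn(1, 6, 4, 64, 64)    # [B, C, F, H, W] camera-condition frames
out = adapter(x)
print(out.shape)                    # torch.Size([1, 64, 4, 4, 4]): 64 -> 8 (unshuffle) -> 4 (stride-2 conv)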
videox_fun/models/wan_image_encoder.py ADDED
@@ -0,0 +1,553 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision.transforms as T
9
+
10
+ from .attention_utils import attention, flash_attention
11
+ from .wan_xlm_roberta import XLMRoberta
12
+ from diffusers.configuration_utils import ConfigMixin
13
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+
16
+
17
+ __all__ = [
18
+ 'XLMRobertaCLIP',
19
+ 'clip_xlm_roberta_vit_h_14',
20
+ 'CLIPModel',
21
+ ]
22
+
23
+
24
+ def pos_interpolate(pos, seq_len):
25
+ if pos.size(1) == seq_len:
26
+ return pos
27
+ else:
28
+ src_grid = int(math.sqrt(pos.size(1)))
29
+ tar_grid = int(math.sqrt(seq_len))
30
+ n = pos.size(1) - src_grid * src_grid
31
+ return torch.cat([
32
+ pos[:, :n],
33
+ F.interpolate(
34
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
35
+ 0, 3, 1, 2),
36
+ size=(tar_grid, tar_grid),
37
+ mode='bicubic',
38
+ align_corners=False).flatten(2).transpose(1, 2)
39
+ ],
40
+ dim=1)
41
+
42
+
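As a hedged illustration of pos_interpolate above (the shapes are arbitrary assumptions): any leading class-token slots are kept as-is while the square patch grid is resized bicubically.
import torch

pos = torch.randn(1, 1 + 16 * 16, 1280)               # assumed: 1 cls slot + 16x16 patch grid
resized = pos_interpolate(pos, seq_len=1 + 32 * 32)    # grid upsampled to 32x32
print(resized.shape)                                   # torch.Size([1, 1025, 1280])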
43
+ class QuickGELU(nn.Module):
44
+
45
+ def forward(self, x):
46
+ return x * torch.sigmoid(1.702 * x)
47
+
48
+
49
+ class LayerNorm(nn.LayerNorm):
50
+
51
+ def forward(self, x):
52
+ return super().forward(x.float()).type_as(x)
53
+
54
+
55
+ class SelfAttention(nn.Module):
56
+
57
+ def __init__(self,
58
+ dim,
59
+ num_heads,
60
+ causal=False,
61
+ attn_dropout=0.0,
62
+ proj_dropout=0.0):
63
+ assert dim % num_heads == 0
64
+ super().__init__()
65
+ self.dim = dim
66
+ self.num_heads = num_heads
67
+ self.head_dim = dim // num_heads
68
+ self.causal = causal
69
+ self.attn_dropout = attn_dropout
70
+ self.proj_dropout = proj_dropout
71
+
72
+ # layers
73
+ self.to_qkv = nn.Linear(dim, dim * 3)
74
+ self.proj = nn.Linear(dim, dim)
75
+
76
+ def forward(self, x):
77
+ """
78
+ x: [B, L, C].
79
+ """
80
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
81
+
82
+ # compute query, key, value
83
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
84
+
85
+ # compute attention
86
+ p = self.attn_dropout if self.training else 0.0
87
+ x = attention(q, k, v, dropout_p=p, causal=self.causal, attention_type="none")
88
+ x = x.reshape(b, s, c)
89
+
90
+ # output
91
+ x = self.proj(x)
92
+ x = F.dropout(x, self.proj_dropout, self.training)
93
+ return x
94
+
95
+
96
+ class SwiGLU(nn.Module):
97
+
98
+ def __init__(self, dim, mid_dim):
99
+ super().__init__()
100
+ self.dim = dim
101
+ self.mid_dim = mid_dim
102
+
103
+ # layers
104
+ self.fc1 = nn.Linear(dim, mid_dim)
105
+ self.fc2 = nn.Linear(dim, mid_dim)
106
+ self.fc3 = nn.Linear(mid_dim, dim)
107
+
108
+ def forward(self, x):
109
+ x = F.silu(self.fc1(x)) * self.fc2(x)
110
+ x = self.fc3(x)
111
+ return x
112
+
113
+
114
+ class AttentionBlock(nn.Module):
115
+
116
+ def __init__(self,
117
+ dim,
118
+ mlp_ratio,
119
+ num_heads,
120
+ post_norm=False,
121
+ causal=False,
122
+ activation='quick_gelu',
123
+ attn_dropout=0.0,
124
+ proj_dropout=0.0,
125
+ norm_eps=1e-5):
126
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
127
+ super().__init__()
128
+ self.dim = dim
129
+ self.mlp_ratio = mlp_ratio
130
+ self.num_heads = num_heads
131
+ self.post_norm = post_norm
132
+ self.causal = causal
133
+ self.norm_eps = norm_eps
134
+
135
+ # layers
136
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
137
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
138
+ proj_dropout)
139
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
140
+ if activation == 'swi_glu':
141
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
142
+ else:
143
+ self.mlp = nn.Sequential(
144
+ nn.Linear(dim, int(dim * mlp_ratio)),
145
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
146
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
147
+
148
+ def forward(self, x):
149
+ if self.post_norm:
150
+ x = x + self.norm1(self.attn(x))
151
+ x = x + self.norm2(self.mlp(x))
152
+ else:
153
+ x = x + self.attn(self.norm1(x))
154
+ x = x + self.mlp(self.norm2(x))
155
+ return x
156
+
157
+
158
+ class AttentionPool(nn.Module):
159
+
160
+ def __init__(self,
161
+ dim,
162
+ mlp_ratio,
163
+ num_heads,
164
+ activation='gelu',
165
+ proj_dropout=0.0,
166
+ norm_eps=1e-5):
167
+ assert dim % num_heads == 0
168
+ super().__init__()
169
+ self.dim = dim
170
+ self.mlp_ratio = mlp_ratio
171
+ self.num_heads = num_heads
172
+ self.head_dim = dim // num_heads
173
+ self.proj_dropout = proj_dropout
174
+ self.norm_eps = norm_eps
175
+
176
+ # layers
177
+ gain = 1.0 / math.sqrt(dim)
178
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
179
+ self.to_q = nn.Linear(dim, dim)
180
+ self.to_kv = nn.Linear(dim, dim * 2)
181
+ self.proj = nn.Linear(dim, dim)
182
+ self.norm = LayerNorm(dim, eps=norm_eps)
183
+ self.mlp = nn.Sequential(
184
+ nn.Linear(dim, int(dim * mlp_ratio)),
185
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
186
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
187
+
188
+ def forward(self, x):
189
+ """
190
+ x: [B, L, C].
191
+ """
192
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
193
+
194
+ # compute query, key, value
195
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
196
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
197
+
198
+ # compute attention
199
+ x = flash_attention(q, k, v, version=2)
200
+ x = x.reshape(b, 1, c)
201
+
202
+ # output
203
+ x = self.proj(x)
204
+ x = F.dropout(x, self.proj_dropout, self.training)
205
+
206
+ # mlp
207
+ x = x + self.mlp(self.norm(x))
208
+ return x[:, 0]
209
+
210
+
211
+ class VisionTransformer(nn.Module):
212
+
213
+ def __init__(self,
214
+ image_size=224,
215
+ patch_size=16,
216
+ dim=768,
217
+ mlp_ratio=4,
218
+ out_dim=512,
219
+ num_heads=12,
220
+ num_layers=12,
221
+ pool_type='token',
222
+ pre_norm=True,
223
+ post_norm=False,
224
+ activation='quick_gelu',
225
+ attn_dropout=0.0,
226
+ proj_dropout=0.0,
227
+ embedding_dropout=0.0,
228
+ norm_eps=1e-5):
229
+ if image_size % patch_size != 0:
230
+ print(
231
+ '[WARNING] image_size is not divisible by patch_size',
232
+ flush=True)
233
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
234
+ out_dim = out_dim or dim
235
+ super().__init__()
236
+ self.image_size = image_size
237
+ self.patch_size = patch_size
238
+ self.num_patches = (image_size // patch_size)**2
239
+ self.dim = dim
240
+ self.mlp_ratio = mlp_ratio
241
+ self.out_dim = out_dim
242
+ self.num_heads = num_heads
243
+ self.num_layers = num_layers
244
+ self.pool_type = pool_type
245
+ self.post_norm = post_norm
246
+ self.norm_eps = norm_eps
247
+
248
+ # embeddings
249
+ gain = 1.0 / math.sqrt(dim)
250
+ self.patch_embedding = nn.Conv2d(
251
+ 3,
252
+ dim,
253
+ kernel_size=patch_size,
254
+ stride=patch_size,
255
+ bias=not pre_norm)
256
+ if pool_type in ('token', 'token_fc'):
257
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
258
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
259
+ 1, self.num_patches +
260
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
261
+ self.dropout = nn.Dropout(embedding_dropout)
262
+
263
+ # transformer
264
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
265
+ self.transformer = nn.Sequential(*[
266
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
267
+ activation, attn_dropout, proj_dropout, norm_eps)
268
+ for _ in range(num_layers)
269
+ ])
270
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
271
+
272
+ # head
273
+ if pool_type == 'token':
274
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
275
+ elif pool_type == 'token_fc':
276
+ self.head = nn.Linear(dim, out_dim)
277
+ elif pool_type == 'attn_pool':
278
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
279
+ proj_dropout, norm_eps)
280
+
281
+ def forward(self, x, interpolation=False, use_31_block=False):
282
+ b = x.size(0)
283
+
284
+ # embeddings
285
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
286
+ if self.pool_type in ('token', 'token_fc'):
287
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
288
+ if interpolation:
289
+ e = pos_interpolate(self.pos_embedding, x.size(1))
290
+ else:
291
+ e = self.pos_embedding
292
+ x = self.dropout(x + e)
293
+ if self.pre_norm is not None:
294
+ x = self.pre_norm(x)
295
+
296
+ # transformer
297
+ if use_31_block:
298
+ x = self.transformer[:-1](x)
299
+ return x
300
+ else:
301
+ x = self.transformer(x)
302
+ return x
303
+
304
+
305
+ class XLMRobertaWithHead(XLMRoberta):
306
+
307
+ def __init__(self, **kwargs):
308
+ self.out_dim = kwargs.pop('out_dim')
309
+ super().__init__(**kwargs)
310
+
311
+ # head
312
+ mid_dim = (self.dim + self.out_dim) // 2
313
+ self.head = nn.Sequential(
314
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
315
+ nn.Linear(mid_dim, self.out_dim, bias=False))
316
+
317
+ def forward(self, ids):
318
+ # xlm-roberta
319
+ x = super().forward(ids)
320
+
321
+ # average pooling
322
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
323
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
324
+
325
+ # head
326
+ x = self.head(x)
327
+ return x
328
+
329
+
330
+ class XLMRobertaCLIP(nn.Module):
331
+
332
+ def __init__(self,
333
+ embed_dim=1024,
334
+ image_size=224,
335
+ patch_size=14,
336
+ vision_dim=1280,
337
+ vision_mlp_ratio=4,
338
+ vision_heads=16,
339
+ vision_layers=32,
340
+ vision_pool='token',
341
+ vision_pre_norm=True,
342
+ vision_post_norm=False,
343
+ activation='gelu',
344
+ vocab_size=250002,
345
+ max_text_len=514,
346
+ type_size=1,
347
+ pad_id=1,
348
+ text_dim=1024,
349
+ text_heads=16,
350
+ text_layers=24,
351
+ text_post_norm=True,
352
+ text_dropout=0.1,
353
+ attn_dropout=0.0,
354
+ proj_dropout=0.0,
355
+ embedding_dropout=0.0,
356
+ norm_eps=1e-5):
357
+ super().__init__()
358
+ self.embed_dim = embed_dim
359
+ self.image_size = image_size
360
+ self.patch_size = patch_size
361
+ self.vision_dim = vision_dim
362
+ self.vision_mlp_ratio = vision_mlp_ratio
363
+ self.vision_heads = vision_heads
364
+ self.vision_layers = vision_layers
365
+ self.vision_pre_norm = vision_pre_norm
366
+ self.vision_post_norm = vision_post_norm
367
+ self.activation = activation
368
+ self.vocab_size = vocab_size
369
+ self.max_text_len = max_text_len
370
+ self.type_size = type_size
371
+ self.pad_id = pad_id
372
+ self.text_dim = text_dim
373
+ self.text_heads = text_heads
374
+ self.text_layers = text_layers
375
+ self.text_post_norm = text_post_norm
376
+ self.norm_eps = norm_eps
377
+
378
+ # models
379
+ self.visual = VisionTransformer(
380
+ image_size=image_size,
381
+ patch_size=patch_size,
382
+ dim=vision_dim,
383
+ mlp_ratio=vision_mlp_ratio,
384
+ out_dim=embed_dim,
385
+ num_heads=vision_heads,
386
+ num_layers=vision_layers,
387
+ pool_type=vision_pool,
388
+ pre_norm=vision_pre_norm,
389
+ post_norm=vision_post_norm,
390
+ activation=activation,
391
+ attn_dropout=attn_dropout,
392
+ proj_dropout=proj_dropout,
393
+ embedding_dropout=embedding_dropout,
394
+ norm_eps=norm_eps)
395
+ self.textual = XLMRobertaWithHead(
396
+ vocab_size=vocab_size,
397
+ max_seq_len=max_text_len,
398
+ type_size=type_size,
399
+ pad_id=pad_id,
400
+ dim=text_dim,
401
+ out_dim=embed_dim,
402
+ num_heads=text_heads,
403
+ num_layers=text_layers,
404
+ post_norm=text_post_norm,
405
+ dropout=text_dropout)
406
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
407
+
408
+ def forward(self, imgs, txt_ids):
409
+ """
410
+ imgs: [B, 3, H, W] of torch.float32.
411
+ - mean: [0.48145466, 0.4578275, 0.40821073]
412
+ - std: [0.26862954, 0.26130258, 0.27577711]
413
+ txt_ids: [B, L] of torch.long.
414
+ Encoded by data.CLIPTokenizer.
415
+ """
416
+ xi = self.visual(imgs)
417
+ xt = self.textual(txt_ids)
418
+ return xi, xt
419
+
420
+ def param_groups(self):
421
+ groups = [{
422
+ 'params': [
423
+ p for n, p in self.named_parameters()
424
+ if 'norm' in n or n.endswith('bias')
425
+ ],
426
+ 'weight_decay': 0.0
427
+ }, {
428
+ 'params': [
429
+ p for n, p in self.named_parameters()
430
+ if not ('norm' in n or n.endswith('bias'))
431
+ ]
432
+ }]
433
+ return groups
434
+
435
+
436
+ def _clip(pretrained=False,
437
+ pretrained_name=None,
438
+ model_cls=XLMRobertaCLIP,
439
+ return_transforms=False,
440
+ return_tokenizer=False,
441
+ tokenizer_padding='eos',
442
+ dtype=torch.float32,
443
+ device='cpu',
444
+ **kwargs):
445
+ # init a model on device
446
+ with torch.device(device):
447
+ model = model_cls(**kwargs)
448
+
449
+ # set device
450
+ model = model.to(dtype=dtype, device=device)
451
+ output = (model,)
452
+
453
+ # init transforms
454
+ if return_transforms:
455
+ # mean and std
456
+ if 'siglip' in pretrained_name.lower():
457
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
458
+ else:
459
+ mean = [0.48145466, 0.4578275, 0.40821073]
460
+ std = [0.26862954, 0.26130258, 0.27577711]
461
+
462
+ # transforms
463
+ transforms = T.Compose([
464
+ T.Resize((model.image_size, model.image_size),
465
+ interpolation=T.InterpolationMode.BICUBIC),
466
+ T.ToTensor(),
467
+ T.Normalize(mean=mean, std=std)
468
+ ])
469
+ output += (transforms,)
470
+ return output[0] if len(output) == 1 else output
471
+
472
+
473
+ def clip_xlm_roberta_vit_h_14(
474
+ pretrained=False,
475
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
476
+ **kwargs):
477
+ cfg = dict(
478
+ embed_dim=1024,
479
+ image_size=224,
480
+ patch_size=14,
481
+ vision_dim=1280,
482
+ vision_mlp_ratio=4,
483
+ vision_heads=16,
484
+ vision_layers=32,
485
+ vision_pool='token',
486
+ activation='gelu',
487
+ vocab_size=250002,
488
+ max_text_len=514,
489
+ type_size=1,
490
+ pad_id=1,
491
+ text_dim=1024,
492
+ text_heads=16,
493
+ text_layers=24,
494
+ text_post_norm=True,
495
+ text_dropout=0.1,
496
+ attn_dropout=0.0,
497
+ proj_dropout=0.0,
498
+ embedding_dropout=0.0)
499
+ cfg.update(**kwargs)
500
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
501
+
502
+
503
+ class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
504
+
505
+ def __init__(self):
506
+ super(CLIPModel, self).__init__()
507
+ # init model
508
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
509
+ pretrained=False,
510
+ return_transforms=True,
511
+ return_tokenizer=False)
512
+
513
+ def forward(self, videos):
514
+ # preprocess
515
+ size = (self.model.image_size,) * 2
516
+ videos = torch.cat([
517
+ F.interpolate(
518
+ u.transpose(0, 1),
519
+ size=size,
520
+ mode='bicubic',
521
+ align_corners=False) for u in videos
522
+ ])
523
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
524
+
525
+ # forward
526
+ with torch.cuda.amp.autocast(dtype=self.dtype):
527
+ out = self.model.visual(videos, use_31_block=True)
528
+ return out
529
+
530
+ @classmethod
531
+ def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}):
532
+ def filter_kwargs(cls, kwargs):
533
+ import inspect
534
+ sig = inspect.signature(cls.__init__)
535
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
536
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
537
+ return filtered_kwargs
538
+
539
+ model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
540
+ if pretrained_model_path.endswith(".safetensors"):
541
+ from safetensors.torch import load_file, safe_open
542
+ state_dict = load_file(pretrained_model_path)
543
+ else:
544
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
545
+ tmp_state_dict = {}
546
+ for key in state_dict:
547
+ tmp_state_dict["model." + key] = state_dict[key]
548
+ state_dict = tmp_state_dict
549
+ m, u = model.load_state_dict(state_dict)
550
+
551
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
552
+ print(m, u)
553
+ return model
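To make the data flow of CLIPModel.forward above concrete, here is a hedged sketch; it assumes a CUDA device with fp16 weights and randomly initialised parameters (real use would go through CLIPModel.from_pretrained), and the video size is arbitrary.
import torch
from videox_fun.models import CLIPModel

clip = CLIPModel().to("cuda", torch.float16)
video = torch.randn(3, 2, 480, 832, device="cuda", dtype=torch.float16).clamp(-1, 1)  # [C, F, H, W] in [-1, 1]
with torch.no_grad():
    context = clip([video])          # resized to 224x224, normalised, passed through the first 31 ViT blocks
print(context.shape)                 # torch.Size([2, 257, 1280]): one row per frame, cls token + 16x16 patches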
videox_fun/models/wan_text_encoder.py ADDED
@@ -0,0 +1,395 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from diffusers.configuration_utils import ConfigMixin
10
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+
13
+
14
+ def fp16_clamp(x):
15
+ if x.dtype == torch.float16 and torch.isinf(x).any():
16
+ clamp = torch.finfo(x.dtype).max - 1000
17
+ x = torch.clamp(x, min=-clamp, max=clamp)
18
+ return x
19
+
20
+
21
+ def init_weights(m):
22
+ if isinstance(m, T5LayerNorm):
23
+ nn.init.ones_(m.weight)
24
+ elif isinstance(m, T5FeedForward):
25
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
26
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
27
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
28
+ elif isinstance(m, T5Attention):
29
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
30
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
31
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
32
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
33
+ elif isinstance(m, T5RelativeEmbedding):
34
+ nn.init.normal_(
35
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
36
+
37
+
38
+ class GELU(nn.Module):
39
+ def forward(self, x):
40
+ return 0.5 * x * (1.0 + torch.tanh(
41
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
42
+
43
+
44
+ class T5LayerNorm(nn.Module):
45
+ def __init__(self, dim, eps=1e-6):
46
+ super(T5LayerNorm, self).__init__()
47
+ self.dim = dim
48
+ self.eps = eps
49
+ self.weight = nn.Parameter(torch.ones(dim))
50
+
51
+ def forward(self, x):
52
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
53
+ self.eps)
54
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
55
+ x = x.type_as(self.weight)
56
+ return self.weight * x
57
+
58
+
59
+ class T5Attention(nn.Module):
60
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
61
+ assert dim_attn % num_heads == 0
62
+ super(T5Attention, self).__init__()
63
+ self.dim = dim
64
+ self.dim_attn = dim_attn
65
+ self.num_heads = num_heads
66
+ self.head_dim = dim_attn // num_heads
67
+
68
+ # layers
69
+ self.q = nn.Linear(dim, dim_attn, bias=False)
70
+ self.k = nn.Linear(dim, dim_attn, bias=False)
71
+ self.v = nn.Linear(dim, dim_attn, bias=False)
72
+ self.o = nn.Linear(dim_attn, dim, bias=False)
73
+ self.dropout = nn.Dropout(dropout)
74
+
75
+ def forward(self, x, context=None, mask=None, pos_bias=None):
76
+ """
77
+ x: [B, L1, C].
78
+ context: [B, L2, C] or None.
79
+ mask: [B, L2] or [B, L1, L2] or None.
80
+ """
81
+ # check inputs
82
+ context = x if context is None else context
83
+ b, n, c = x.size(0), self.num_heads, self.head_dim
84
+
85
+ # compute query, key, value
86
+ q = self.q(x).view(b, -1, n, c)
87
+ k = self.k(context).view(b, -1, n, c)
88
+ v = self.v(context).view(b, -1, n, c)
89
+
90
+ # attention bias
91
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
92
+ if pos_bias is not None:
93
+ attn_bias += pos_bias
94
+ if mask is not None:
95
+ assert mask.ndim in [2, 3]
96
+ mask = mask.view(b, 1, 1,
97
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
98
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
99
+
100
+ # compute attention (T5 does not use scaling)
101
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
102
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
103
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
104
+
105
+ # output
106
+ x = x.reshape(b, -1, n * c)
107
+ x = self.o(x)
108
+ x = self.dropout(x)
109
+ return x
110
+
111
+
112
+ class T5FeedForward(nn.Module):
113
+
114
+ def __init__(self, dim, dim_ffn, dropout=0.1):
115
+ super(T5FeedForward, self).__init__()
116
+ self.dim = dim
117
+ self.dim_ffn = dim_ffn
118
+
119
+ # layers
120
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
121
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
122
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
123
+ self.dropout = nn.Dropout(dropout)
124
+
125
+ def forward(self, x):
126
+ x = self.fc1(x) * self.gate(x)
127
+ x = self.dropout(x)
128
+ x = self.fc2(x)
129
+ x = self.dropout(x)
130
+ return x
131
+
132
+
133
+ class T5SelfAttention(nn.Module):
134
+ def __init__(self,
135
+ dim,
136
+ dim_attn,
137
+ dim_ffn,
138
+ num_heads,
139
+ num_buckets,
140
+ shared_pos=True,
141
+ dropout=0.1):
142
+ super(T5SelfAttention, self).__init__()
143
+ self.dim = dim
144
+ self.dim_attn = dim_attn
145
+ self.dim_ffn = dim_ffn
146
+ self.num_heads = num_heads
147
+ self.num_buckets = num_buckets
148
+ self.shared_pos = shared_pos
149
+
150
+ # layers
151
+ self.norm1 = T5LayerNorm(dim)
152
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
153
+ self.norm2 = T5LayerNorm(dim)
154
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
155
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
156
+ num_buckets, num_heads, bidirectional=True)
157
+
158
+ def forward(self, x, mask=None, pos_bias=None):
159
+ e = pos_bias if self.shared_pos else self.pos_embedding(
160
+ x.size(1), x.size(1))
161
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
162
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
163
+ return x
164
+
165
+
166
+ class T5CrossAttention(nn.Module):
167
+ def __init__(self,
168
+ dim,
169
+ dim_attn,
170
+ dim_ffn,
171
+ num_heads,
172
+ num_buckets,
173
+ shared_pos=True,
174
+ dropout=0.1):
175
+ super(T5CrossAttention, self).__init__()
176
+ self.dim = dim
177
+ self.dim_attn = dim_attn
178
+ self.dim_ffn = dim_ffn
179
+ self.num_heads = num_heads
180
+ self.num_buckets = num_buckets
181
+ self.shared_pos = shared_pos
182
+
183
+ # layers
184
+ self.norm1 = T5LayerNorm(dim)
185
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
186
+ self.norm2 = T5LayerNorm(dim)
187
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
188
+ self.norm3 = T5LayerNorm(dim)
189
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
190
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
191
+ num_buckets, num_heads, bidirectional=False)
192
+
193
+ def forward(self,
194
+ x,
195
+ mask=None,
196
+ encoder_states=None,
197
+ encoder_mask=None,
198
+ pos_bias=None):
199
+ e = pos_bias if self.shared_pos else self.pos_embedding(
200
+ x.size(1), x.size(1))
201
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
202
+ x = fp16_clamp(x + self.cross_attn(
203
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
204
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
205
+ return x
206
+
207
+
208
+ class T5RelativeEmbedding(nn.Module):
209
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
210
+ super(T5RelativeEmbedding, self).__init__()
211
+ self.num_buckets = num_buckets
212
+ self.num_heads = num_heads
213
+ self.bidirectional = bidirectional
214
+ self.max_dist = max_dist
215
+
216
+ # layers
217
+ self.embedding = nn.Embedding(num_buckets, num_heads)
218
+
219
+ def forward(self, lq, lk):
220
+ device = self.embedding.weight.device
221
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
222
+ # torch.arange(lq).unsqueeze(1).to(device)
223
+ if torch.device(type="meta") != device:
224
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
225
+ torch.arange(lq, device=device).unsqueeze(1)
226
+ else:
227
+ rel_pos = torch.arange(lk).unsqueeze(0) - \
228
+ torch.arange(lq).unsqueeze(1)
229
+ rel_pos = self._relative_position_bucket(rel_pos)
230
+ rel_pos_embeds = self.embedding(rel_pos)
231
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
232
+ 0) # [1, N, Lq, Lk]
233
+ return rel_pos_embeds.contiguous()
234
+
235
+ def _relative_position_bucket(self, rel_pos):
236
+ # preprocess
237
+ if self.bidirectional:
238
+ num_buckets = self.num_buckets // 2
239
+ rel_buckets = (rel_pos > 0).long() * num_buckets
240
+ rel_pos = torch.abs(rel_pos)
241
+ else:
242
+ num_buckets = self.num_buckets
243
+ rel_buckets = 0
244
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
245
+
246
+ # embeddings for small and large positions
247
+ max_exact = num_buckets // 2
248
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
249
+ math.log(self.max_dist / max_exact) *
250
+ (num_buckets - max_exact)).long()
251
+ rel_pos_large = torch.min(
252
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
253
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
254
+ return rel_buckets
255
+
256
+ class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
257
+ def __init__(self,
258
+ vocab,
259
+ dim,
260
+ dim_attn,
261
+ dim_ffn,
262
+ num_heads,
263
+ num_layers,
264
+ num_buckets,
265
+ shared_pos=True,
266
+ dropout=0.1):
267
+ super(WanT5EncoderModel, self).__init__()
268
+ self.dim = dim
269
+ self.dim_attn = dim_attn
270
+ self.dim_ffn = dim_ffn
271
+ self.num_heads = num_heads
272
+ self.num_layers = num_layers
273
+ self.num_buckets = num_buckets
274
+ self.shared_pos = shared_pos
275
+
276
+ # layers
277
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
278
+ else nn.Embedding(vocab, dim)
279
+ self.pos_embedding = T5RelativeEmbedding(
280
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
281
+ self.dropout = nn.Dropout(dropout)
282
+ self.blocks = nn.ModuleList([
283
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
284
+ shared_pos, dropout) for _ in range(num_layers)
285
+ ])
286
+ self.norm = T5LayerNorm(dim)
287
+
288
+ # initialize weights
289
+ self.apply(init_weights)
290
+
291
+ def forward(
292
+ self,
293
+ input_ids: Optional[torch.LongTensor] = None,
294
+ attention_mask: Optional[torch.FloatTensor] = None,
295
+ ):
296
+ x = self.token_embedding(input_ids)
297
+ x = self.dropout(x)
298
+ e = self.pos_embedding(x.size(1),
299
+ x.size(1)) if self.shared_pos else None
300
+ for block in self.blocks:
301
+ x = block(x, attention_mask, pos_bias=e)
302
+ x = self.norm(x)
303
+ x = self.dropout(x)
304
+ return (x, )
305
+
306
+ @classmethod
307
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16):
308
+ def filter_kwargs(cls, kwargs):
309
+ import inspect
310
+ sig = inspect.signature(cls.__init__)
311
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
312
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
313
+ return filtered_kwargs
314
+
315
+ if low_cpu_mem_usage:
316
+ try:
317
+ import re
318
+
319
+ from diffusers import __version__ as diffusers_version
320
+ if diffusers_version >= "0.33.0":
321
+ from diffusers.models.model_loading_utils import \
322
+ load_model_dict_into_meta
323
+ else:
324
+ from diffusers.models.modeling_utils import \
325
+ load_model_dict_into_meta
326
+ from diffusers.utils import is_accelerate_available
327
+ if is_accelerate_available():
328
+ import accelerate
329
+
330
+ # Instantiate model with empty weights
331
+ with accelerate.init_empty_weights():
332
+ model = cls(**filter_kwargs(cls, additional_kwargs))
333
+
334
+ param_device = "cpu"
335
+ if pretrained_model_path.endswith(".safetensors"):
336
+ from safetensors.torch import load_file
337
+ state_dict = load_file(pretrained_model_path)
338
+ else:
339
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
340
+
341
+ if diffusers_version >= "0.33.0":
342
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
343
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
344
+ load_model_dict_into_meta(
345
+ model,
346
+ state_dict,
347
+ dtype=torch_dtype,
348
+ model_name_or_path=pretrained_model_path,
349
+ )
350
+ else:
351
+ # move the params from meta device to cpu
352
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
353
+ if len(missing_keys) > 0:
354
+ raise ValueError(
355
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
356
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
357
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
358
+ " those weights or else make sure your checkpoint file is correct."
359
+ )
360
+
361
+ unexpected_keys = load_model_dict_into_meta(
362
+ model,
363
+ state_dict,
364
+ device=param_device,
365
+ dtype=torch_dtype,
366
+ model_name_or_path=pretrained_model_path,
367
+ )
368
+
369
+ if cls._keys_to_ignore_on_load_unexpected is not None:
370
+ for pat in cls._keys_to_ignore_on_load_unexpected:
371
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
372
+
373
+ if len(unexpected_keys) > 0:
374
+ print(
375
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
376
+ )
377
+
378
+ return model
379
+ except Exception as e:
380
+ print(
381
+ f"The low_cpu_mem_usage mode does not work because {e}. Falling back to low_cpu_mem_usage=False."
382
+ )
383
+
384
+ model = cls(**filter_kwargs(cls, additional_kwargs))
385
+ if pretrained_model_path.endswith(".safetensors"):
386
+ from safetensors.torch import load_file, safe_open
387
+ state_dict = load_file(pretrained_model_path)
388
+ else:
389
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
390
+ m, u = model.load_state_dict(state_dict, strict=False)
391
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
392
+ print(m, u)
393
+
394
+ model = model.to(torch_dtype)
395
+ return model
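The encoder-only stack above is driven through WanT5EncoderModel.forward(input_ids, attention_mask). The tiny configuration below is an assumption chosen purely to keep the sketch small; real checkpoints use a much larger UMT5 configuration loaded via from_pretrained.
import torch
from videox_fun.models import WanT5EncoderModel

encoder = WanT5EncoderModel(
    vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
    num_heads=4, num_layers=2, num_buckets=32,
)
ids = torch.randint(0, 1000, (1, 16))
mask = torch.ones(1, 16, dtype=torch.long)
(hidden,) = encoder(input_ids=ids, attention_mask=mask)
print(hidden.shape)   # torch.Size([1, 16, 64])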
videox_fun/models/wan_transformer3d.py ADDED
@@ -0,0 +1,1399 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+
4
+ import glob
5
+ import json
6
+ import math
7
+ import os
8
+ import types
9
+ import warnings
10
+ from typing import Any, Dict, Optional, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.nn as nn
16
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
17
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.utils import is_torch_version, logging
20
+ from torch import nn
21
+
22
+ from ..dist import (get_sequence_parallel_rank,
23
+ get_sequence_parallel_world_size, get_sp_group,
24
+ usp_attn_forward, xFuserLongContextAttention)
25
+ from ..utils import cfg_skip
26
+ from .attention_utils import attention
27
+ from .cache_utils import TeaCache
28
+ from .wan_camera_adapter import SimpleAdapter
29
+
30
+
31
+ def sinusoidal_embedding_1d(dim, position):
32
+ # preprocess
33
+ assert dim % 2 == 0
34
+ half = dim // 2
35
+ position = position.type(torch.float64)
36
+
37
+ # calculation
38
+ sinusoid = torch.outer(
39
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
40
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
41
+ return x
42
+
43
+
44
+ @amp.autocast(enabled=False)
45
+ def rope_params(max_seq_len, dim, theta=10000):
46
+ assert dim % 2 == 0
47
+ freqs = torch.outer(
48
+ torch.arange(max_seq_len),
49
+ 1.0 / torch.pow(theta,
50
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
51
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
52
+ return freqs
53
+
54
+
55
+ # modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py
56
+ @amp.autocast(enabled=False)
57
+ def get_1d_rotary_pos_embed_riflex(
58
+ pos: Union[np.ndarray, int],
59
+ dim: int,
60
+ theta: float = 10000.0,
61
+ use_real=False,
62
+ k: Optional[int] = None,
63
+ L_test: Optional[int] = None,
64
+ L_test_scale: Optional[int] = None,
65
+ ):
66
+ """
67
+ RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
68
+
69
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
70
+ position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex
71
+ values.
72
+
73
+ Args:
74
+ dim (`int`): Dimension of the frequency tensor.
75
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
76
+ theta (`float`, *optional*, defaults to 10000.0):
77
+ Scaling factor for frequency computation. Defaults to 10000.0.
78
+ use_real (`bool`, *optional*):
79
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
80
+ k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
81
+ L_test (`int`, *optional*, defaults to None): the number of frames for inference
+ L_test_scale (`float`, *optional*, defaults to None): extra divisor applied to the intrinsic frequency after the RIFLEx adjustment
82
+ Returns:
83
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
84
+ """
85
+ assert dim % 2 == 0
86
+
87
+ if isinstance(pos, int):
88
+ pos = torch.arange(pos)
89
+ if isinstance(pos, np.ndarray):
90
+ pos = torch.from_numpy(pos) # type: ignore # [S]
91
+
92
+ freqs = 1.0 / torch.pow(theta,
93
+ torch.arange(0, dim, 2).to(torch.float64).div(dim))
94
+
95
+ # === Riflex modification start ===
96
+ # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
97
+ # Empirical observations show that a few videos may exhibit repetition in the tail frames.
98
+ # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
99
+ if k is not None:
100
+ freqs[k-1] = 0.9 * 2 * torch.pi / L_test
101
+ # === Riflex modification end ===
102
+ if L_test_scale is not None:
103
+ freqs[k-1] = freqs[k-1] / L_test_scale
104
+
105
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
106
+ if use_real:
107
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
108
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
109
+ return freqs_cos, freqs_sin
110
+ else:
111
+ # lumina
112
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
113
+ return freqs_cis
114
+
115
+
116
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
117
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
118
+ tw = tgt_width
119
+ th = tgt_height
120
+ h, w = src
121
+ r = h / w
122
+ if r > (th / tw):
123
+ resize_height = th
124
+ resize_width = int(round(th / h * w))
125
+ else:
126
+ resize_width = tw
127
+ resize_height = int(round(tw / w * h))
128
+
129
+ crop_top = int(round((th - resize_height) / 2.0))
130
+ crop_left = int(round((tw - resize_width) / 2.0))
131
+
132
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
133
+
134
+
135
+ @amp.autocast(enabled=False)
136
+ @torch.compiler.disable()
137
+ def rope_apply(x, grid_sizes, freqs, frame_split_indices=None, ground_frame_indices=None):
138
+ n, c = x.size(2), x.size(3) // 2
139
+
140
+ # split freqs
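+ # into (temporal, height, width) channel groups, mirroring how self.freqs is
+ # concatenated in WanTransformer3DModel.__init__.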
141
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
142
+
143
+ # loop over samples
144
+ output = []
145
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
146
+ seq_len = f * h * w
147
+
148
+ # precompute multipliers
149
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
150
+ seq_len, n, -1, 2))
151
+
152
+ # Handle temporal freqs with split for paired data and ground frames
153
+ if frame_split_indices is not None and i < len(frame_split_indices):
154
+ # print("applying repeat rope")
155
+ # print(f"[ROPE] frame_split_indices: {frame_split_indices}")
156
+ # Split temporal positions: with ground frames, src uses positions [1..f_src], ground frames are
+ # pinned to position 0 and tgt restarts at [1..f_tgt]; without ground frames, src and tgt both start from 0.
157
+ f_src = frame_split_indices[i]
158
+
159
+ # Check if we have ground frames
160
+ if ground_frame_indices is not None and i < len(ground_frame_indices):
161
+ ground_start, ground_end = ground_frame_indices[i]
162
+ f_ground = ground_end - ground_start
163
+ f_tgt = f - f_src - f_ground
164
+
165
+ # print(f"[ROPE] CoT data: f={f}, f_src={f_src}, f_ground={f_ground}, f_tgt={f_tgt}")
166
+ # print(f"[ROPE] ground_frame_indices: {ground_frame_indices}")
167
+ # exit()
168
+ # Generate independent temporal freqs
169
+ # Src: positions [1..f_src]
170
+
171
+ freqs_src_t = freqs[0][1:f_src + 1].view(f_src, 1, 1, -1).expand(f_src, h, w, -1)
172
+
173
+ # Ground: force all frames to use position 0
174
+ freqs_ground_t = freqs[0][:1].view(1, 1, 1, -1).repeat(f_ground, h, w, 1)
175
+
176
+ # Tgt: positions [1..f_tgt]
177
+ freqs_tgt_t = freqs[0][1:f_tgt + 1].view(f_tgt, 1, 1, -1).expand(f_tgt, h, w, -1)
178
+
179
+ freqs_temporal = torch.cat([freqs_src_t, freqs_ground_t, freqs_tgt_t], dim=0)
180
+ else:
181
+ # No ground frames, regular paired data
182
+ # print(f"[ROPE] Paired data: f={f}, f_src={f_src}, f_tgt={f - f_src}")
183
+ f_tgt = f - f_src
184
+
185
+ # Generate independent temporal freqs for src and tgt
186
+ freqs_src_t = freqs[0][:f_src].view(f_src, 1, 1, -1).expand(f_src, h, w, -1)
187
+ freqs_tgt_t = freqs[0][:f_tgt].view(f_tgt, 1, 1, -1).expand(f_tgt, h, w, -1)
188
+ freqs_temporal = torch.cat([freqs_src_t, freqs_tgt_t], dim=0)
189
+ else:
190
+ # Default: continuous temporal positions [0, f-1]
191
+ freqs_temporal = freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1)
192
+
193
+ # Combine temporal + spatial freqs
194
+ freqs_i = torch.cat([
195
+ freqs_temporal,
196
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
197
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
198
+ ], dim=-1).reshape(seq_len, 1, -1)
199
+
200
+ # apply rotary embedding
201
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
202
+ x_i = torch.cat([x_i, x[i, seq_len:]])
203
+ # append to collection
204
+ output.append(x_i)
205
+ return torch.stack(output).to(x.dtype)
206
+
207
+
208
+ def rope_apply_qk(q, k, grid_sizes, freqs, frame_split_indices=None, ground_frame_indices=None):
209
+ q = rope_apply(q, grid_sizes, freqs, frame_split_indices, ground_frame_indices)
210
+ k = rope_apply(k, grid_sizes, freqs, frame_split_indices, ground_frame_indices)
211
+ return q, k
212
+
213
+
214
+ class WanRMSNorm(nn.Module):
215
+
216
+ def __init__(self, dim, eps=1e-5):
217
+ super().__init__()
218
+ self.dim = dim
219
+ self.eps = eps
220
+ self.weight = nn.Parameter(torch.ones(dim))
221
+
222
+ def forward(self, x):
223
+ r"""
224
+ Args:
225
+ x(Tensor): Shape [B, L, C]
226
+ """
227
+ return self._norm(x) * self.weight
228
+
229
+ def _norm(self, x):
230
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(x.dtype)
231
+
232
+
233
+ class WanLayerNorm(nn.LayerNorm):
234
+
235
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
236
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
237
+
238
+ def forward(self, x):
239
+ r"""
240
+ Args:
241
+ x(Tensor): Shape [B, L, C]
242
+ """
243
+ return super().forward(x)
244
+
245
+
246
+ class WanSelfAttention(nn.Module):
247
+
248
+ def __init__(self,
249
+ dim,
250
+ num_heads,
251
+ window_size=(-1, -1),
252
+ qk_norm=True,
253
+ eps=1e-6):
254
+ assert dim % num_heads == 0
255
+ super().__init__()
256
+ self.dim = dim
257
+ self.num_heads = num_heads
258
+ self.head_dim = dim // num_heads
259
+ self.window_size = window_size
260
+ self.qk_norm = qk_norm
261
+ self.eps = eps
262
+
263
+ # layers
264
+ self.q = nn.Linear(dim, dim)
265
+ self.k = nn.Linear(dim, dim)
266
+ self.v = nn.Linear(dim, dim)
267
+ self.o = nn.Linear(dim, dim)
268
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
269
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
270
+
271
+ def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0, frame_split_indices=None, ground_frame_indices=None):
272
+ r"""
273
+ Args:
274
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
275
+ seq_lens(Tensor): Shape [B]
276
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
277
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
278
+ frame_split_indices(List[int], optional): Split indices for paired data temporal RoPE
279
+ ground_frame_indices(List[Tuple[int, int]], optional): Ground frame positions for special temporal RoPE
280
+ """
281
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
282
+
283
+ # query, key, value function
284
+ def qkv_fn(x):
285
+ q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d)
286
+ k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d)
287
+ v = self.v(x.to(dtype)).view(b, s, n, d)
288
+ return q, k, v
289
+
290
+ q, k, v = qkv_fn(x)
291
+
292
+ q, k = rope_apply_qk(q, k, grid_sizes, freqs, frame_split_indices, ground_frame_indices)
293
+
294
+ x = attention(
295
+ q.to(dtype),
296
+ k.to(dtype),
297
+ v=v.to(dtype),
298
+ k_lens=seq_lens,
299
+ window_size=self.window_size)
300
+ x = x.to(dtype)
301
+
302
+ # output
303
+ x = x.flatten(2)
304
+ x = self.o(x)
305
+ return x
306
+
307
+
308
+ class WanT2VCrossAttention(WanSelfAttention):
309
+
310
+ def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
311
+ r"""
312
+ Args:
313
+ x(Tensor): Shape [B, L1, C]
314
+ context(Tensor): Shape [B, L2, C]
315
+ context_lens(Tensor): Shape [B]
316
+ """
317
+ b, n, d = x.size(0), self.num_heads, self.head_dim
318
+
319
+ # compute query, key, value
320
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
321
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
322
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
323
+
324
+ # compute attention
325
+ x = attention(
326
+ q.to(dtype),
327
+ k.to(dtype),
328
+ v.to(dtype),
329
+ k_lens=context_lens
330
+ )
331
+ x = x.to(dtype)
332
+
333
+ # output
334
+ x = x.flatten(2)
335
+ x = self.o(x)
336
+ return x
337
+
338
+
339
+ class WanI2VCrossAttention(WanSelfAttention):
340
+
341
+ def __init__(self,
342
+ dim,
343
+ num_heads,
344
+ window_size=(-1, -1),
345
+ qk_norm=True,
346
+ eps=1e-6):
347
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
348
+
349
+ self.k_img = nn.Linear(dim, dim)
350
+ self.v_img = nn.Linear(dim, dim)
351
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
352
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
353
+
354
+ def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
355
+ r"""
356
+ Args:
357
+ x(Tensor): Shape [B, L1, C]
358
+ context(Tensor): Shape [B, L2, C]
359
+ context_lens(Tensor): Shape [B]
360
+ """
361
+ context_img = context[:, :257]
362
+ context = context[:, 257:]
363
+ b, n, d = x.size(0), self.num_heads, self.head_dim
364
+
365
+ # compute query, key, value
366
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
367
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
368
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
369
+ k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d)
370
+ v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d)
371
+
372
+ img_x = attention(
373
+ q.to(dtype),
374
+ k_img.to(dtype),
375
+ v_img.to(dtype),
376
+ k_lens=None
377
+ )
378
+ img_x = img_x.to(dtype)
379
+ # compute attention
380
+ x = attention(
381
+ q.to(dtype),
382
+ k.to(dtype),
383
+ v.to(dtype),
384
+ k_lens=context_lens
385
+ )
386
+ x = x.to(dtype)
387
+
388
+ # output
389
+ x = x.flatten(2)
390
+ img_x = img_x.flatten(2)
391
+ x = x + img_x
392
+ x = self.o(x)
393
+ return x
394
+
395
+
396
+ class WanCrossAttention(WanSelfAttention):
397
+ def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
398
+ r"""
399
+ Args:
400
+ x(Tensor): Shape [B, L1, C]
401
+ context(Tensor): Shape [B, L2, C]
402
+ context_lens(Tensor): Shape [B]
403
+ """
404
+ b, n, d = x.size(0), self.num_heads, self.head_dim
405
+ # compute query, key, value
406
+ q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
407
+ k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
408
+ v = self.v(context.to(dtype)).view(b, -1, n, d)
409
+ # compute attention
410
+ x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens)
411
+ # output
412
+ x = x.flatten(2)
413
+ x = self.o(x.to(dtype))
414
+ return x
415
+
416
+
417
+ WAN_CROSSATTENTION_CLASSES = {
418
+ 't2v_cross_attn': WanT2VCrossAttention,
419
+ 'i2v_cross_attn': WanI2VCrossAttention,
420
+ 'cross_attn': WanCrossAttention,
421
+ }
422
+
423
+
424
+ class WanAttentionBlock(nn.Module):
425
+
426
+ def __init__(self,
427
+ cross_attn_type,
428
+ dim,
429
+ ffn_dim,
430
+ num_heads,
431
+ window_size=(-1, -1),
432
+ qk_norm=True,
433
+ cross_attn_norm=False,
434
+ eps=1e-6):
435
+ super().__init__()
436
+ self.dim = dim
437
+ self.ffn_dim = ffn_dim
438
+ self.num_heads = num_heads
439
+ self.window_size = window_size
440
+ self.qk_norm = qk_norm
441
+ self.cross_attn_norm = cross_attn_norm
442
+ self.eps = eps
443
+
444
+ # layers
445
+ self.norm1 = WanLayerNorm(dim, eps)
446
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
447
+ eps)
448
+ self.norm3 = WanLayerNorm(
449
+ dim, eps,
450
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
451
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
452
+ num_heads,
453
+ (-1, -1),
454
+ qk_norm,
455
+ eps)
456
+ self.norm2 = WanLayerNorm(dim, eps)
457
+ self.ffn = nn.Sequential(
458
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
459
+ nn.Linear(ffn_dim, dim))
460
+
461
+ # modulation
462
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
463
+
464
+ def forward(
465
+ self,
466
+ x,
467
+ e,
468
+ seq_lens,
469
+ grid_sizes,
470
+ freqs,
471
+ context,
472
+ context_lens,
473
+ dtype=torch.bfloat16,
474
+ t=0,
475
+ frame_split_indices=None,
476
+ ground_frame_indices=None,
477
+ ):
478
+ r"""
479
+ Args:
480
+ x(Tensor): Shape [B, L, C]
481
+ e(Tensor): Shape [B, 6, C]
482
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
483
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
484
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
485
+ frame_split_indices(List[int], optional): Split indices for paired data temporal RoPE
486
+ ground_frame_indices(List[Tuple[int, int]], optional): Ground frame positions for special temporal RoPE
487
+ """
488
+ if e.dim() > 3:
489
+ e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
490
+ e = [e.squeeze(2) for e in e]
491
+ else:
492
+ e = (self.modulation + e).chunk(6, dim=1)
493
+
494
+ # self-attention
495
+ temp_x = self.norm1(x) * (1 + e[1]) + e[0]
496
+ temp_x = temp_x.to(dtype)
497
+
498
+ y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype, t=t, frame_split_indices=frame_split_indices, ground_frame_indices=ground_frame_indices)
499
+ x = x + y * e[2]
500
+
501
+ # cross-attention & ffn function
502
+ def cross_attn_ffn(x, context, context_lens, e):
503
+ # cross-attention
504
+ x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype, t=t)
505
+
506
+ # ffn function
507
+ temp_x = self.norm2(x) * (1 + e[4]) + e[3]
508
+ temp_x = temp_x.to(dtype)
509
+
510
+ y = self.ffn(temp_x)
511
+ x = x + y * e[5]
512
+ return x
513
+
514
+ x = cross_attn_ffn(x, context, context_lens, e)
515
+ return x
516
+
517
+
518
+ class Head(nn.Module):
519
+
520
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
521
+ super().__init__()
522
+ self.dim = dim
523
+ self.out_dim = out_dim
524
+ self.patch_size = patch_size
525
+ self.eps = eps
526
+
527
+ # layers
528
+ out_dim = math.prod(patch_size) * out_dim
529
+ self.norm = WanLayerNorm(dim, eps)
530
+ self.head = nn.Linear(dim, out_dim)
531
+
532
+ # modulation
533
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
534
+
535
+ def forward(self, x, e):
536
+ r"""
537
+ Args:
538
+ x(Tensor): Shape [B, L1, C]
539
+ e(Tensor): Shape [B, C]
540
+ """
541
+ if e.dim() > 2:
542
+ e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
543
+ e = [e.squeeze(2) for e in e]
544
+ else:
545
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
546
+
547
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
548
+ return x
549
+
550
+
551
+ class MLPProj(torch.nn.Module):
552
+
553
+ def __init__(self, in_dim, out_dim):
554
+ super().__init__()
555
+
556
+ self.proj = torch.nn.Sequential(
557
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
558
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
559
+ torch.nn.LayerNorm(out_dim))
560
+
561
+ def forward(self, image_embeds):
562
+ clip_extra_context_tokens = self.proj(image_embeds)
563
+ return clip_extra_context_tokens
564
+
565
+
566
+
567
+ class WanTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
568
+ r"""
569
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
570
+ """
571
+
572
+ # ignore_for_config = [
573
+ # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
574
+ # ]
575
+ # _no_split_modules = ['WanAttentionBlock']
576
+ _supports_gradient_checkpointing = True
577
+
578
+ @register_to_config
579
+ def __init__(
580
+ self,
581
+ model_type='t2v',
582
+ patch_size=(1, 2, 2),
583
+ text_len=512,
584
+ in_dim=16,
585
+ dim=2048,
586
+ ffn_dim=8192,
587
+ freq_dim=256,
588
+ text_dim=4096,
589
+ out_dim=16,
590
+ num_heads=16,
591
+ num_layers=32,
592
+ window_size=(-1, -1),
593
+ qk_norm=True,
594
+ cross_attn_norm=True,
595
+ eps=1e-6,
596
+ in_channels=16,
597
+ hidden_size=2048,
598
+ add_control_adapter=False,
599
+ in_dim_control_adapter=24,
600
+ downscale_factor_control_adapter=8,
601
+ add_ref_conv=False,
602
+ in_dim_ref_conv=16,
603
+ cross_attn_type=None,
604
+ ):
605
+ r"""
606
+ Initialize the diffusion model backbone.
607
+
608
+ Args:
609
+ model_type (`str`, *optional*, defaults to 't2v'):
610
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
611
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
612
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
613
+ text_len (`int`, *optional*, defaults to 512):
614
+ Fixed length for text embeddings
615
+ in_dim (`int`, *optional*, defaults to 16):
616
+ Input video channels (C_in)
617
+ dim (`int`, *optional*, defaults to 2048):
618
+ Hidden dimension of the transformer
619
+ ffn_dim (`int`, *optional*, defaults to 8192):
620
+ Intermediate dimension in feed-forward network
621
+ freq_dim (`int`, *optional*, defaults to 256):
622
+ Dimension for sinusoidal time embeddings
623
+ text_dim (`int`, *optional*, defaults to 4096):
624
+ Input dimension for text embeddings
625
+ out_dim (`int`, *optional*, defaults to 16):
626
+ Output video channels (C_out)
627
+ num_heads (`int`, *optional*, defaults to 16):
628
+ Number of attention heads
629
+ num_layers (`int`, *optional*, defaults to 32):
630
+ Number of transformer blocks
631
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
632
+ Window size for local attention (-1 indicates global attention)
633
+ qk_norm (`bool`, *optional*, defaults to True):
634
+ Enable query/key normalization
635
+ cross_attn_norm (`bool`, *optional*, defaults to True):
636
+ Enable cross-attention normalization
637
+ eps (`float`, *optional*, defaults to 1e-6):
638
+ Epsilon value for normalization layers
639
+ """
640
+
641
+ super().__init__()
642
+
643
+ # assert model_type in ['t2v', 'i2v', 'ti2v']
644
+ self.model_type = model_type
645
+
646
+ self.patch_size = patch_size
647
+ self.text_len = text_len
648
+ self.in_dim = in_dim
649
+ self.dim = dim
650
+ self.ffn_dim = ffn_dim
651
+ self.freq_dim = freq_dim
652
+ self.text_dim = text_dim
653
+ self.out_dim = out_dim
654
+ self.num_heads = num_heads
655
+ self.num_layers = num_layers
656
+ self.window_size = window_size
657
+ self.qk_norm = qk_norm
658
+ self.cross_attn_norm = cross_attn_norm
659
+ self.eps = eps
660
+
661
+ # embeddings
662
+ self.patch_embedding = nn.Conv3d(
663
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
664
+ self.text_embedding = nn.Sequential(
665
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
666
+ nn.Linear(dim, dim))
667
+
668
+ self.time_embedding = nn.Sequential(
669
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
670
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
671
+
672
+ # blocks
673
+ if cross_attn_type is None:
674
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
675
+ self.blocks = nn.ModuleList([
676
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
677
+ window_size, qk_norm, cross_attn_norm, eps)
678
+ for _ in range(num_layers)
679
+ ])
680
+ for layer_idx, block in enumerate(self.blocks):
681
+ block.self_attn.layer_idx = layer_idx
682
+ block.self_attn.num_layers = self.num_layers
683
+
684
+ # head
685
+ self.head = Head(dim, out_dim, patch_size, eps)
686
+
687
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
688
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
689
+ d = dim // num_heads
690
+ self.d = d
691
+ self.dim = dim
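+ # Rotary tables for the (temporal, height, width) axes: the per-head channel count d
+ # is split into d - 4*(d//6) temporal channels and 2*(d//6) channels per spatial axis.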
692
+ self.freqs = torch.cat(
693
+ [
694
+ rope_params(1024, d - 4 * (d // 6)),
695
+ rope_params(1024, 2 * (d // 6)),
696
+ rope_params(1024, 2 * (d // 6))
697
+ ],
698
+ dim=1
699
+ )
700
+
701
+ if model_type == 'i2v':
702
+ self.img_emb = MLPProj(1280, dim)
703
+
704
+ if add_control_adapter:
705
+ self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], downscale_factor=downscale_factor_control_adapter)
706
+ else:
707
+ self.control_adapter = None
708
+
709
+ if add_ref_conv:
710
+ self.ref_conv = nn.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
711
+ else:
712
+ self.ref_conv = None
713
+
714
+ self.teacache = None
715
+ self.cfg_skip_ratio = None
716
+ self.current_steps = 0
717
+ self.num_inference_steps = None
718
+ self.gradient_checkpointing = False
719
+ self.sp_world_size = 1
720
+ self.sp_world_rank = 0
721
+ self.init_weights()
722
+
723
+ def _set_gradient_checkpointing(self, *args, **kwargs):
724
+ if "value" in kwargs:
725
+ self.gradient_checkpointing = kwargs["value"]
726
+ elif "enable" in kwargs:
727
+ self.gradient_checkpointing = kwargs["enable"]
728
+ else:
729
+ raise ValueError("Invalid set gradient checkpointing")
730
+
731
+ def enable_teacache(
732
+ self,
733
+ coefficients,
734
+ num_steps: int,
735
+ rel_l1_thresh: float,
736
+ num_skip_start_steps: int = 0,
737
+ offload: bool = True,
738
+ ):
739
+ self.teacache = TeaCache(
740
+ coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
741
+ )
742
+
743
+ def share_teacache(
744
+ self,
745
+ transformer = None,
746
+ ):
747
+ self.teacache = transformer.teacache
748
+
749
+ def disable_teacache(self):
750
+ self.teacache = None
751
+
752
+ def enable_cfg_skip(self, cfg_skip_ratio, num_steps):
753
+ if cfg_skip_ratio != 0:
754
+ self.cfg_skip_ratio = cfg_skip_ratio
755
+ self.current_steps = 0
756
+ self.num_inference_steps = num_steps
757
+ else:
758
+ self.cfg_skip_ratio = None
759
+ self.current_steps = 0
760
+ self.num_inference_steps = None
761
+
762
+ def share_cfg_skip(
763
+ self,
764
+ transformer = None,
765
+ ):
766
+ self.cfg_skip_ratio = transformer.cfg_skip_ratio
767
+ self.current_steps = transformer.current_steps
768
+ self.num_inference_steps = transformer.num_inference_steps
769
+
770
+ def disable_cfg_skip(self):
771
+ self.cfg_skip_ratio = None
772
+ self.current_steps = 0
773
+ self.num_inference_steps = None
774
+
775
+ def enable_riflex(
776
+ self,
777
+ k = 6,
778
+ L_test = 66,
779
+ L_test_scale = 4.886,
780
+ ):
781
+ device = self.freqs.device
782
+ self.freqs = torch.cat(
783
+ [
784
+ get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale),
785
+ rope_params(1024, 2 * (self.d // 6)),
786
+ rope_params(1024, 2 * (self.d // 6))
787
+ ],
788
+ dim=1
789
+ ).to(device)
790
+
791
+ def disable_riflex(self):
792
+ device = self.freqs.device
793
+ self.freqs = torch.cat(
794
+ [
795
+ rope_params(1024, self.d - 4 * (self.d // 6)),
796
+ rope_params(1024, 2 * (self.d // 6)),
797
+ rope_params(1024, 2 * (self.d // 6))
798
+ ],
799
+ dim=1
800
+ ).to(device)
801
+
802
+ def enable_multi_gpus_inference(self,):
803
+ self.sp_world_size = get_sequence_parallel_world_size()
804
+ self.sp_world_rank = get_sequence_parallel_rank()
805
+ self.all_gather = get_sp_group().all_gather
806
+
807
+ # For normal model.
808
+ for block in self.blocks:
809
+ block.self_attn.forward = types.MethodType(
810
+ usp_attn_forward, block.self_attn)
811
+
812
+ # For vace model.
813
+ if hasattr(self, 'vace_blocks'):
814
+ for block in self.vace_blocks:
815
+ block.self_attn.forward = types.MethodType(
816
+ usp_attn_forward, block.self_attn)
817
+
818
+ @cfg_skip()
819
+ def forward(
820
+ self,
821
+ x,
822
+ t,
823
+ context,
824
+ seq_len,
825
+ clip_fea=None,
826
+ y=None,
827
+ y_camera=None,
828
+ full_ref=None,
829
+ subject_ref=None,
830
+ cond_flag=True,
831
+ frame_split_indices=None,
832
+ ground_frame_indices=None,
833
+ ):
834
+ r"""
835
+ Forward pass through the diffusion model
836
+
837
+ Args:
838
+ x (List[Tensor]):
839
+ List of input video tensors, each with shape [C_in, F, H, W]
840
+ t (Tensor):
841
+ Diffusion timesteps tensor of shape [B]
842
+ context (List[Tensor]):
843
+ List of text embeddings each with shape [L, C]
844
+ seq_len (`int`):
845
+ Maximum sequence length for positional encoding
846
+ clip_fea (Tensor, *optional*):
847
+ CLIP image features for image-to-video mode
848
+ y (List[Tensor], *optional*):
849
+ Conditional video inputs for image-to-video mode, same shape as x
850
+ cond_flag (`bool`, *optional*, defaults to True):
851
+ Flag to indicate whether to forward the condition input
852
+
853
+ Returns:
854
+ List[Tensor]:
855
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
856
+ """
857
+ # Wan2.2 doesn't need a CLIP image encoder.
858
+ # if self.model_type == 'i2v':
859
+ # assert clip_fea is not None and y is not None
860
+ # params
861
+ device = self.patch_embedding.weight.device
862
+ dtype = x.dtype
863
+ if self.freqs.device != device and torch.device(type="meta") != device:
864
+ self.freqs = self.freqs.to(device)
865
+
866
+ if y is not None:
867
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
868
+
869
+ # embeddings
870
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
871
+ # add control adapter
872
+ if self.control_adapter is not None and y_camera is not None:
873
+ y_camera = self.control_adapter(y_camera)
874
+ x = [u + v for u, v in zip(x, y_camera)]
875
+
876
+ grid_sizes = torch.stack(
877
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
878
+
879
+ x = [u.flatten(2).transpose(1, 2) for u in x]
880
+ if self.ref_conv is not None and full_ref is not None:
881
+ full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
882
+ grid_sizes = torch.stack([torch.tensor([u[0] + 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
883
+ seq_len += full_ref.size(1)
884
+ x = [torch.concat([_full_ref.unsqueeze(0), u], dim=1) for _full_ref, u in zip(full_ref, x)]
885
+ if t.dim() != 1 and t.size(1) < seq_len:
886
+ pad_size = seq_len - t.size(1)
887
+ last_elements = t[:, -1].unsqueeze(1)
888
+ padding = last_elements.repeat(1, pad_size)
889
+ t = torch.cat([padding, t], dim=1)
890
+
891
+ if subject_ref is not None:
892
+ subject_ref_frames = subject_ref.size(2)
893
+ subject_ref = self.patch_embedding(subject_ref).flatten(2).transpose(1, 2)
894
+ grid_sizes = torch.stack([torch.tensor([u[0] + subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
895
+ seq_len += subject_ref.size(1)
896
+ x = [torch.concat([u, _subject_ref.unsqueeze(0)], dim=1) for _subject_ref, u in zip(subject_ref, x)]
897
+ if t.dim() != 1 and t.size(1) < seq_len:
898
+ pad_size = seq_len - t.size(1)
899
+ last_elements = t[:, -1].unsqueeze(1)
900
+ padding = last_elements.repeat(1, pad_size)
901
+ t = torch.cat([t, padding], dim=1)
902
+
903
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
904
+ if self.sp_world_size > 1:
905
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
906
+ assert seq_lens.max() <= seq_len
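+ # Zero-pad every sample up to the shared seq_len so the list can be stacked into a single batch tensor.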
907
+ x = torch.cat([
908
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
909
+ dim=1) for u in x
910
+ ])
911
+
912
+ # time embeddings
913
+ with amp.autocast(dtype=torch.float32):
914
+ if t.dim() != 1:
915
+ if t.size(1) < seq_len:
916
+ pad_size = seq_len - t.size(1)
917
+ last_elements = t[:, -1].unsqueeze(1)
918
+ padding = last_elements.repeat(1, pad_size)
919
+ t = torch.cat([t, padding], dim=1)
920
+ bt = t.size(0)
921
+ ft = t.flatten()
922
+ e = self.time_embedding(
923
+ sinusoidal_embedding_1d(self.freq_dim,
924
+ ft).unflatten(0, (bt, seq_len)).float())
925
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
926
+ else:
927
+ e = self.time_embedding(
928
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
929
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
930
+
931
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
932
+ # e0 = e0.to(dtype)
933
+ # e = e.to(dtype)
934
+
935
+ # context
936
+ context_lens = None
937
+ context = self.text_embedding(
938
+ torch.stack([
939
+ torch.cat(
940
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
941
+ for u in context
942
+ ]))
943
+
944
+ if clip_fea is not None:
945
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
946
+ context = torch.concat([context_clip, context], dim=1)
947
+
948
+ # Context Parallel
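+ # Each sequence-parallel rank keeps only its chunk of the token dimension here; the
+ # full sequence is reassembled with all_gather after the output head.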
949
+ if self.sp_world_size > 1:
950
+ x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
951
+ if t.dim() != 1:
952
+ e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
953
+ e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
954
+
955
+ # TeaCache
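+ # Decide whether this step can reuse the cached block output: accumulate the rescaled
+ # relative L1 distance between successive modulated inputs and only recompute the
+ # transformer blocks once it exceeds rel_l1_thresh (always compute during the first
+ # num_skip_start_steps steps).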
956
+ if self.teacache is not None:
957
+ if cond_flag:
958
+ if t.dim() != 1:
959
+ modulated_inp = e0[:, -1, :]
960
+ else:
961
+ modulated_inp = e0
962
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
963
+ if skip_flag:
964
+ self.should_calc = True
965
+ self.teacache.accumulated_rel_l1_distance = 0
966
+ else:
967
+ if cond_flag:
968
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
969
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
970
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
971
+ self.should_calc = False
972
+ else:
973
+ self.should_calc = True
974
+ self.teacache.accumulated_rel_l1_distance = 0
975
+ self.teacache.previous_modulated_input = modulated_inp
976
+ self.teacache.should_calc = self.should_calc
977
+ else:
978
+ self.should_calc = self.teacache.should_calc
979
+
980
+ # TeaCache
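+ # Either reuse the cached residual from the previous matching (cond/uncond) pass, or
+ # run all blocks and cache x - ori_x as the new residual.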
981
+ if self.teacache is not None:
982
+ if not self.should_calc:
983
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
984
+ x = x + previous_residual.to(x.device)[-x.size()[0]:,]
985
+ else:
986
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
987
+
988
+ for block in self.blocks:
989
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
990
+
991
+ def create_custom_forward(module):
992
+ def custom_forward(*inputs):
993
+ return module(*inputs)
994
+
995
+ return custom_forward
996
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
997
+ x = torch.utils.checkpoint.checkpoint(
998
+ create_custom_forward(block),
999
+ x,
1000
+ e0,
1001
+ seq_lens,
1002
+ grid_sizes,
1003
+ self.freqs,
1004
+ context,
1005
+ context_lens,
1006
+ dtype,
1007
+ t,
1008
+ frame_split_indices,
1009
+ ground_frame_indices,
1010
+ **ckpt_kwargs,
1011
+ )
1012
+ else:
1013
+ # arguments
1014
+ kwargs = dict(
1015
+ e=e0,
1016
+ seq_lens=seq_lens,
1017
+ grid_sizes=grid_sizes,
1018
+ freqs=self.freqs,
1019
+ context=context,
1020
+ context_lens=context_lens,
1021
+ dtype=dtype,
1022
+ t=t,
1023
+ frame_split_indices=frame_split_indices,
1024
+ ground_frame_indices=ground_frame_indices,
1025
+ )
1026
+ x = block(x, **kwargs)
1027
+
1028
+ if cond_flag:
1029
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
1030
+ else:
1031
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
1032
+ else:
1033
+ for block in self.blocks:
1034
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1035
+
1036
+ def create_custom_forward(module):
1037
+ def custom_forward(*inputs):
1038
+ return module(*inputs)
1039
+
1040
+ return custom_forward
1041
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1042
+ x = torch.utils.checkpoint.checkpoint(
1043
+ create_custom_forward(block),
1044
+ x,
1045
+ e0,
1046
+ seq_lens,
1047
+ grid_sizes,
1048
+ self.freqs,
1049
+ context,
1050
+ context_lens,
1051
+ dtype,
1052
+ t,
1053
+ frame_split_indices,
1054
+ ground_frame_indices,
1055
+ **ckpt_kwargs,
1056
+ )
1057
+ else:
1058
+ # arguments
1059
+ kwargs = dict(
1060
+ e=e0,
1061
+ seq_lens=seq_lens,
1062
+ grid_sizes=grid_sizes,
1063
+ freqs=self.freqs,
1064
+ context=context,
1065
+ context_lens=context_lens,
1066
+ dtype=dtype,
1067
+ t=t,
1068
+ frame_split_indices=frame_split_indices,
1069
+ ground_frame_indices=ground_frame_indices,
1070
+ )
1071
+ x = block(x, **kwargs)
1072
+
1073
+ # head
1074
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1075
+ def create_custom_forward(module):
1076
+ def custom_forward(*inputs):
1077
+ return module(*inputs)
1078
+
1079
+ return custom_forward
1080
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1081
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
1082
+ else:
1083
+ x = self.head(x, e)
1084
+
1085
+ if self.sp_world_size > 1:
1086
+ x = self.all_gather(x, dim=1)
1087
+
1088
+ if self.ref_conv is not None and full_ref is not None:
1089
+ full_ref_length = full_ref.size(1)
1090
+ x = x[:, full_ref_length:]
1091
+ grid_sizes = torch.stack([torch.tensor([u[0] - 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
1092
+
1093
+ if subject_ref is not None:
1094
+ subject_ref_length = subject_ref.size(1)
1095
+ x = x[:, :-subject_ref_length]
1096
+ grid_sizes = torch.stack([torch.tensor([u[0] - subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
1097
+
1098
+ # unpatchify
1099
+ x = self.unpatchify(x, grid_sizes)
1100
+ x = torch.stack(x)
1101
+ if self.teacache is not None and cond_flag:
1102
+ self.teacache.cnt += 1
1103
+ if self.teacache.cnt == self.teacache.num_steps:
1104
+ self.teacache.reset()
1105
+ return x
1106
+
1107
+
1108
+ def unpatchify(self, x, grid_sizes):
1109
+ r"""
1110
+ Reconstruct video tensors from patch embeddings.
1111
+
1112
+ Args:
1113
+ x (List[Tensor]):
1114
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
1115
+ grid_sizes (Tensor):
1116
+ Original spatial-temporal grid dimensions before patching,
1117
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
1118
+
1119
+ Returns:
1120
+ List[Tensor]:
1121
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
1122
+ """
1123
+
1124
+ c = self.out_dim
1125
+ out = []
1126
+ for u, v in zip(x, grid_sizes.tolist()):
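+ # Reassemble one sample: view as (F, H, W, pf, ph, pw, C), permute to
+ # (C, F, pf, H, ph, W, pw) with the einsum below, then merge each patch dim into its
+ # grid dim to recover (C_out, F * pf, H * ph, W * pw).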
1127
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
1128
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
1129
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
1130
+ out.append(u)
1131
+ return out
1132
+
1133
+ def init_weights(self):
1134
+ r"""
1135
+ Initialize model parameters using Xavier initialization.
1136
+ """
1137
+
1138
+ # basic init
1139
+ for m in self.modules():
1140
+ if isinstance(m, nn.Linear):
1141
+ nn.init.xavier_uniform_(m.weight)
1142
+ if m.bias is not None:
1143
+ nn.init.zeros_(m.bias)
1144
+
1145
+ # init embeddings
1146
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
1147
+ for m in self.text_embedding.modules():
1148
+ if isinstance(m, nn.Linear):
1149
+ nn.init.normal_(m.weight, std=.02)
1150
+ for m in self.time_embedding.modules():
1151
+ if isinstance(m, nn.Linear):
1152
+ nn.init.normal_(m.weight, std=.02)
1153
+
1154
+ # init output layer
1155
+ nn.init.zeros_(self.head.head.weight)
1156
+
1157
+ @classmethod
1158
+ def from_pretrained(
1159
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1160
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1161
+ ):
1162
+ if subfolder is not None:
1163
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1164
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1165
+
1166
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1167
+ if not os.path.isfile(config_file):
1168
+ raise RuntimeError(f"{config_file} does not exist")
1169
+ with open(config_file, "r") as f:
1170
+ config = json.load(f)
1171
+
1172
+ from diffusers.utils import WEIGHTS_NAME
1173
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1174
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1175
+
1176
+ if "dict_mapping" in transformer_additional_kwargs.keys():
1177
+ for key in transformer_additional_kwargs["dict_mapping"]:
1178
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
1179
+
1180
+ if low_cpu_mem_usage:
1181
+ try:
1182
+ import re
1183
+
1184
+ from diffusers import __version__ as diffusers_version
1185
+ if diffusers_version >= "0.33.0":
1186
+ from diffusers.models.model_loading_utils import \
1187
+ load_model_dict_into_meta
1188
+ else:
1189
+ from diffusers.models.modeling_utils import \
1190
+ load_model_dict_into_meta
1191
+ from diffusers.utils import is_accelerate_available
1192
+ if is_accelerate_available():
1193
+ import accelerate
1194
+
1195
+ # Instantiate model with empty weights
1196
+ with accelerate.init_empty_weights():
1197
+ model = cls.from_config(config, **transformer_additional_kwargs)
1198
+
1199
+ param_device = "cpu"
1200
+ if os.path.exists(model_file):
1201
+ state_dict = torch.load(model_file, map_location="cpu")
1202
+ elif os.path.exists(model_file_safetensors):
1203
+ from safetensors.torch import load_file, safe_open
1204
+ state_dict = load_file(model_file_safetensors)
1205
+ else:
1206
+ from safetensors.torch import load_file, safe_open
1207
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1208
+ state_dict = {}
1209
+ print(model_files_safetensors)
1210
+ for _model_file_safetensors in model_files_safetensors:
1211
+ _state_dict = load_file(_model_file_safetensors)
1212
+ for key in _state_dict:
1213
+ state_dict[key] = _state_dict[key]
1214
+
1215
+ if diffusers_version >= "0.33.0":
1216
+ # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
1217
+ # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
1218
+ load_model_dict_into_meta(
1219
+ model,
1220
+ state_dict,
1221
+ dtype=torch_dtype,
1222
+ model_name_or_path=pretrained_model_path,
1223
+ )
1224
+ else:
1225
+ model._convert_deprecated_attention_blocks(state_dict)
1226
+ # move the params from meta device to cpu
1227
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
1228
+ if len(missing_keys) > 0:
1229
+ raise ValueError(
1230
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
1231
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
1232
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
1233
+ " those weights or else make sure your checkpoint file is correct."
1234
+ )
1235
+
1236
+ unexpected_keys = load_model_dict_into_meta(
1237
+ model,
1238
+ state_dict,
1239
+ device=param_device,
1240
+ dtype=torch_dtype,
1241
+ model_name_or_path=pretrained_model_path,
1242
+ )
1243
+
1244
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1245
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1246
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1247
+
1248
+ if len(unexpected_keys) > 0:
1249
+ print(
1250
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1251
+ )
1252
+
1253
+ return model
1254
+ except Exception as e:
1255
+ print(
1256
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
1257
+ )
1258
+
1259
+ model = cls.from_config(config, **transformer_additional_kwargs)
1260
+ if os.path.exists(model_file):
1261
+ state_dict = torch.load(model_file, map_location="cpu")
1262
+ elif os.path.exists(model_file_safetensors):
1263
+ from safetensors.torch import load_file, safe_open
1264
+ state_dict = load_file(model_file_safetensors)
1265
+ else:
1266
+ from safetensors.torch import load_file, safe_open
1267
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1268
+ state_dict = {}
1269
+ for _model_file_safetensors in model_files_safetensors:
1270
+ _state_dict = load_file(_model_file_safetensors)
1271
+ for key in _state_dict:
1272
+ state_dict[key] = _state_dict[key]
1273
+
1274
+ if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
1275
+ model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :]
1276
+ model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
1277
+ state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
1278
+
1279
+ tmp_state_dict = {}
1280
+ for key in state_dict:
1281
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1282
+ tmp_state_dict[key] = state_dict[key]
1283
+ else:
1284
+ print(key, "Size don't match, skip")
1285
+
1286
+ state_dict = tmp_state_dict
1287
+
1288
+ m, u = model.load_state_dict(state_dict, strict=False)
1289
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1290
+ print(m)
1291
+
1292
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1293
+ print(f"### All Parameters: {sum(params) / 1e6} M")
1294
+
1295
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1296
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1297
+
1298
+ model = model.to(torch_dtype)
1299
+ return model
1300
+
1301
+
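+ # Minimal usage sketch (the checkpoint path and subfolder below are placeholders, not
+ # assets shipped with this file):
+ #     transformer = WanTransformer3DModel.from_pretrained(
+ #         "path/to/checkpoint", subfolder="transformer",
+ #         torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)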
1302
+ class Wan2_2Transformer3DModel(WanTransformer3DModel):
1303
+ r"""
1304
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
1305
+ """
1306
+
1307
+ # ignore_for_config = [
1308
+ # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
1309
+ # ]
1310
+ # _no_split_modules = ['WanAttentionBlock']
1311
+ _supports_gradient_checkpointing = True
1312
+
1313
+ def __init__(
1314
+ self,
1315
+ model_type='t2v',
1316
+ patch_size=(1, 2, 2),
1317
+ text_len=512,
1318
+ in_dim=16,
1319
+ dim=2048,
1320
+ ffn_dim=8192,
1321
+ freq_dim=256,
1322
+ text_dim=4096,
1323
+ out_dim=16,
1324
+ num_heads=16,
1325
+ num_layers=32,
1326
+ window_size=(-1, -1),
1327
+ qk_norm=True,
1328
+ cross_attn_norm=True,
1329
+ eps=1e-6,
1330
+ in_channels=16,
1331
+ hidden_size=2048,
1332
+ add_control_adapter=False,
1333
+ in_dim_control_adapter=24,
1334
+ downscale_factor_control_adapter=8,
1335
+ add_ref_conv=False,
1336
+ in_dim_ref_conv=16,
1337
+ ):
1338
+ r"""
1339
+ Initialize the diffusion model backbone.
1340
+ Args:
1341
+ model_type (`str`, *optional*, defaults to 't2v'):
1342
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
1343
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
1344
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
1345
+ text_len (`int`, *optional*, defaults to 512):
1346
+ Fixed length for text embeddings
1347
+ in_dim (`int`, *optional*, defaults to 16):
1348
+ Input video channels (C_in)
1349
+ dim (`int`, *optional*, defaults to 2048):
1350
+ Hidden dimension of the transformer
1351
+ ffn_dim (`int`, *optional*, defaults to 8192):
1352
+ Intermediate dimension in feed-forward network
1353
+ freq_dim (`int`, *optional*, defaults to 256):
1354
+ Dimension for sinusoidal time embeddings
1355
+ text_dim (`int`, *optional*, defaults to 4096):
1356
+ Input dimension for text embeddings
1357
+ out_dim (`int`, *optional*, defaults to 16):
1358
+ Output video channels (C_out)
1359
+ num_heads (`int`, *optional*, defaults to 16):
1360
+ Number of attention heads
1361
+ num_layers (`int`, *optional*, defaults to 32):
1362
+ Number of transformer blocks
1363
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
1364
+ Window size for local attention (-1 indicates global attention)
1365
+ qk_norm (`bool`, *optional*, defaults to True):
1366
+ Enable query/key normalization
1367
+ cross_attn_norm (`bool`, *optional*, defaults to True):
1368
+ Enable cross-attention normalization
1369
+ eps (`float`, *optional*, defaults to 1e-6):
1370
+ Epsilon value for normalization layers
1371
+ """
1372
+ super().__init__(
1373
+ model_type=model_type,
1374
+ patch_size=patch_size,
1375
+ text_len=text_len,
1376
+ in_dim=in_dim,
1377
+ dim=dim,
1378
+ ffn_dim=ffn_dim,
1379
+ freq_dim=freq_dim,
1380
+ text_dim=text_dim,
1381
+ out_dim=out_dim,
1382
+ num_heads=num_heads,
1383
+ num_layers=num_layers,
1384
+ window_size=window_size,
1385
+ qk_norm=qk_norm,
1386
+ cross_attn_norm=cross_attn_norm,
1387
+ eps=eps,
1388
+ in_channels=in_channels,
1389
+ hidden_size=hidden_size,
1390
+ add_control_adapter=add_control_adapter,
1391
+ in_dim_control_adapter=in_dim_control_adapter,
1392
+ downscale_factor_control_adapter=downscale_factor_control_adapter,
1393
+ add_ref_conv=add_ref_conv,
1394
+ in_dim_ref_conv=in_dim_ref_conv,
1395
+ cross_attn_type="cross_attn"
1396
+ )
1397
+
1398
+ if hasattr(self, "img_emb"):
1399
+ del self.img_emb
videox_fun/models/wan_transformer3d_s2v.py ADDED
@@ -0,0 +1,887 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/model_s2v.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+
4
+ import math
5
+ import types
6
+ from copy import deepcopy
7
+ from typing import Any, Dict
8
+
9
+ import torch
10
+ import torch.cuda.amp as amp
11
+ import torch.nn as nn
12
+ from diffusers.configuration_utils import register_to_config
13
+ from diffusers.utils import is_torch_version
14
+ from einops import rearrange
15
+
16
+ from ..dist import (get_sequence_parallel_rank,
17
+ get_sequence_parallel_world_size, get_sp_group,
18
+ usp_attn_s2v_forward)
19
+ from .attention_utils import attention
20
+ from .wan_audio_injector import (AudioInjector_WAN, CausalAudioEncoder,
21
+ FramePackMotioner, MotionerTransformers,
22
+ rope_precompute)
23
+ from .wan_transformer3d import (Wan2_2Transformer3DModel, WanAttentionBlock,
24
+ WanLayerNorm, WanSelfAttention,
25
+ sinusoidal_embedding_1d)
26
+
27
+
28
+ def zero_module(module):
29
+ """
30
+ Zero out the parameters of a module and return it.
31
+ """
32
+ for p in module.parameters():
33
+ p.detach().zero_()
34
+ return module
35
+
36
+
37
+ def torch_dfs(model: nn.Module, parent_name='root'):
38
+ module_names, modules = [], []
39
+ current_name = parent_name if parent_name else 'root'
40
+ module_names.append(current_name)
41
+ modules.append(model)
42
+
43
+ for name, child in model.named_children():
44
+ if parent_name:
45
+ child_name = f'{parent_name}.{name}'
46
+ else:
47
+ child_name = name
48
+ child_modules, child_names = torch_dfs(child, child_name)
49
+ module_names += child_names
50
+ modules += child_modules
51
+ return modules, module_names
52
+
53
+
54
+ @amp.autocast(enabled=False)
55
+ @torch.compiler.disable()
56
+ def s2v_rope_apply(x, grid_sizes, freqs, start=None):
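+ # Unlike rope_apply in wan_transformer3d, `freqs` here appears to be a per-sample,
+ # per-token table (see rope_precompute from wan_audio_injector), so it is indexed
+ # directly instead of being rebuilt from grid_sizes.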
57
+ n, c = x.size(2), x.size(3) // 2
58
+ # loop over samples
59
+ output = []
60
+ for i, _ in enumerate(x):
61
+ s = x.size(1)
62
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
63
+ freqs_i = freqs[i, :s]
64
+ # apply rotary embedding
65
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
66
+ x_i = torch.cat([x_i, x[i, s:]])
67
+ # append to collection
68
+ output.append(x_i)
69
+ return torch.stack(output).float()
70
+
71
+
72
+ def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
73
+ q = s2v_rope_apply(q, grid_sizes, freqs)
74
+ k = s2v_rope_apply(k, grid_sizes, freqs)
75
+ return q, k
76
+
77
+
78
+ class WanS2VSelfAttention(WanSelfAttention):
79
+
80
+ def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0):
81
+ """
82
+ Args:
83
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
84
+ seq_lens(Tensor): Shape [B]
85
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
86
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
87
+ """
88
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
89
+
90
+ # query, key, value function
91
+ def qkv_fn(x):
92
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
93
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
94
+ v = self.v(x).view(b, s, n, d)
95
+ return q, k, v
96
+
97
+ q, k, v = qkv_fn(x)
98
+
99
+ q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)
100
+
101
+ x = attention(
102
+ q.to(dtype),
103
+ k.to(dtype),
104
+ v=v.to(dtype),
105
+ k_lens=seq_lens,
106
+ window_size=self.window_size)
107
+ x = x.to(dtype)
108
+
109
+ # output
110
+ x = x.flatten(2)
111
+ x = self.o(x)
112
+ return x
113
+
114
+
115
+ class WanS2VAttentionBlock(WanAttentionBlock):
116
+
117
+ def __init__(self,
118
+ cross_attn_type,
119
+ dim,
120
+ ffn_dim,
121
+ num_heads,
122
+ window_size=(-1, -1),
123
+ qk_norm=True,
124
+ cross_attn_norm=False,
125
+ eps=1e-6):
126
+ super().__init__(
127
+ cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps
128
+ )
129
+ self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size, qk_norm, eps)
130
+
131
+ def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens, dtype=torch.bfloat16, t=0):
132
+ # unpack e: e[0] carries the six modulation tensors, e[1] is the segment split index
133
+ seg_idx = e[1].item()
134
+ seg_idx = min(max(0, seg_idx), x.size(1))
135
+ seg_idx = [0, seg_idx, x.size(1)]
136
+ e = e[0]
137
+ modulation = self.modulation.unsqueeze(2)
138
+ e = (modulation + e).chunk(6, dim=1)
139
+ e = [element.squeeze(1) for element in e]
140
+
141
+ # norm
142
+ norm_x = self.norm1(x).float()
143
+ parts = []
144
+ for i in range(2):
145
+ parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] *
146
+ (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
147
+ norm_x = torch.cat(parts, dim=1)
148
+ # self-attention
149
+ y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs)
150
+ with amp.autocast(dtype=torch.float32):
151
+ z = []
152
+ for i in range(2):
153
+ z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
154
+ y = torch.cat(z, dim=1)
155
+ x = x + y
156
+
157
+ # cross-attention & ffn function
158
+ def cross_attn_ffn(x, context, context_lens, e):
159
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
160
+ norm2_x = self.norm2(x).float()
161
+ parts = []
162
+ for i in range(2):
163
+ parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] *
164
+ (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
165
+ norm2_x = torch.cat(parts, dim=1)
166
+ y = self.ffn(norm2_x)
167
+ with amp.autocast(dtype=torch.float32):
168
+ z = []
169
+ for i in range(2):
170
+ z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
171
+ y = torch.cat(z, dim=1)
172
+ x = x + y
173
+ return x
174
+
175
+ x = cross_attn_ffn(x, context, context_lens, e)
176
+ return x
177
+
178
+
179
+ class Wan2_2Transformer3DModel_S2V(Wan2_2Transformer3DModel):
180
+ # ignore_for_config = [
181
+ # 'args', 'kwargs', 'patch_size', 'cross_attn_norm', 'qk_norm',
182
+ # 'text_dim', 'window_size'
183
+ # ]
184
+ # _no_split_modules = ['WanS2VAttentionBlock']
185
+
186
+ @register_to_config
187
+ def __init__(
188
+ self,
189
+ cond_dim=0,
190
+ audio_dim=5120,
191
+ num_audio_token=4,
192
+ enable_adain=False,
193
+ adain_mode="attn_norm",
194
+ audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
195
+ zero_init=False,
196
+ zero_timestep=False,
197
+ enable_motioner=True,
198
+ add_last_motion=True,
199
+ enable_tsm=False,
200
+ trainable_token_pos_emb=False,
201
+ motion_token_num=1024,
202
+ enable_framepack=False, # Mutually exclusive with enable_motioner
203
+ framepack_drop_mode="drop",
204
+ model_type='s2v',
205
+ patch_size=(1, 2, 2),
206
+ text_len=512,
207
+ in_dim=16,
208
+ dim=2048,
209
+ ffn_dim=8192,
210
+ freq_dim=256,
211
+ text_dim=4096,
212
+ out_dim=16,
213
+ num_heads=16,
214
+ num_layers=32,
215
+ window_size=(-1, -1),
216
+ qk_norm=True,
217
+ cross_attn_norm=True,
218
+ eps=1e-6,
219
+ in_channels=16,
220
+ hidden_size=2048,
221
+ *args,
222
+ **kwargs
223
+ ):
224
+ super().__init__(
225
+ model_type=model_type,
226
+ patch_size=patch_size,
227
+ text_len=text_len,
228
+ in_dim=in_dim,
229
+ dim=dim,
230
+ ffn_dim=ffn_dim,
231
+ freq_dim=freq_dim,
232
+ text_dim=text_dim,
233
+ out_dim=out_dim,
234
+ num_heads=num_heads,
235
+ num_layers=num_layers,
236
+ window_size=window_size,
237
+ qk_norm=qk_norm,
238
+ cross_attn_norm=cross_attn_norm,
239
+ eps=eps,
240
+ in_channels=in_channels,
241
+ hidden_size=hidden_size
242
+ )
243
+
244
+ assert model_type == 's2v'
245
+ self.enbale_adain = enable_adain
246
+ self.adain_mode = adain_mode
248
+ # Whether to assign a 0-value timestep to ref/motion tokens
249
+ self.zero_timestep = zero_timestep
249
+ self.enable_motioner = enable_motioner
250
+ self.add_last_motion = add_last_motion
251
+ self.enable_framepack = enable_framepack
252
+
253
+ # Replace blocks
254
+ self.blocks = nn.ModuleList([
255
+ WanS2VAttentionBlock("cross_attn", dim, ffn_dim, num_heads, window_size, qk_norm,
256
+ cross_attn_norm, eps)
257
+ for _ in range(num_layers)
258
+ ])
259
+
260
+ # init audio injector
261
+ all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
262
+ if cond_dim > 0:
263
+ self.cond_encoder = nn.Conv3d(
264
+ cond_dim,
265
+ self.dim,
266
+ kernel_size=self.patch_size,
267
+ stride=self.patch_size)
268
+ self.trainable_cond_mask = nn.Embedding(3, self.dim)
269
+ self.casual_audio_encoder = CausalAudioEncoder(
270
+ dim=audio_dim,
271
+ out_dim=self.dim,
272
+ num_token=num_audio_token,
273
+ need_global=enable_adain)
274
+ self.audio_injector = AudioInjector_WAN(
275
+ all_modules,
276
+ all_modules_names,
277
+ dim=self.dim,
278
+ num_heads=self.num_heads,
279
+ inject_layer=audio_inject_layers,
280
+ root_net=self,
281
+ enable_adain=enable_adain,
282
+ adain_dim=self.dim,
283
+ need_adain_ont=adain_mode != "attn_norm",
284
+ )
285
+
286
+ if zero_init:
287
+ self.zero_init_weights()
288
+
289
+ # init motioner
290
+ if enable_motioner and enable_framepack:
291
+ raise ValueError(
292
+ "enable_motioner and enable_framepack are mutually exclusive, please set one of them to False"
293
+ )
294
+ if enable_motioner:
295
+ motioner_dim = 2048
296
+ self.motioner = MotionerTransformers(
297
+ patch_size=(2, 4, 4),
298
+ dim=motioner_dim,
299
+ ffn_dim=motioner_dim,
300
+ freq_dim=256,
301
+ out_dim=16,
302
+ num_heads=16,
303
+ num_layers=13,
304
+ window_size=(-1, -1),
305
+ qk_norm=True,
306
+ cross_attn_norm=False,
307
+ eps=1e-6,
308
+ motion_token_num=motion_token_num,
309
+ enable_tsm=enable_tsm,
310
+ motion_stride=4,
311
+ expand_ratio=2,
312
+ trainable_token_pos_emb=trainable_token_pos_emb,
313
+ )
314
+ self.zip_motion_out = torch.nn.Sequential(
315
+ WanLayerNorm(motioner_dim),
316
+ zero_module(nn.Linear(motioner_dim, self.dim)))
317
+
318
+ self.trainable_token_pos_emb = trainable_token_pos_emb
319
+ if trainable_token_pos_emb:
320
+ d = self.dim // self.num_heads
321
+ x = torch.zeros([1, motion_token_num, self.num_heads, d])
322
+ x[..., ::2] = 1
323
+
324
+ gride_sizes = [[
325
+ torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
326
+ torch.tensor([
327
+ 1, self.motioner.motion_side_len,
328
+ self.motioner.motion_side_len
329
+ ]).unsqueeze(0).repeat(1, 1),
330
+ torch.tensor([
331
+ 1, self.motioner.motion_side_len,
332
+ self.motioner.motion_side_len
333
+ ]).unsqueeze(0).repeat(1, 1),
334
+ ]]
335
+ token_freqs = s2v_rope_apply(x, gride_sizes, self.freqs)
336
+ token_freqs = token_freqs[0, :,
337
+ 0].reshape(motion_token_num, -1, 2)
338
+ token_freqs = token_freqs * 0.01
339
+ self.token_freqs = torch.nn.Parameter(token_freqs)
340
+
341
+ if enable_framepack:
342
+ self.frame_packer = FramePackMotioner(
343
+ inner_dim=self.dim,
344
+ num_heads=self.num_heads,
345
+ zip_frame_buckets=[1, 2, 16],
346
+ drop_mode=framepack_drop_mode)
347
+
348
+ def enable_multi_gpus_inference(self,):
349
+ self.sp_world_size = get_sequence_parallel_world_size()
350
+ self.sp_world_rank = get_sequence_parallel_rank()
351
+ self.all_gather = get_sp_group().all_gather
352
+ for block in self.blocks:
353
+ block.self_attn.forward = types.MethodType(
354
+ usp_attn_s2v_forward, block.self_attn)
355
+
356
+ def process_motion(self, motion_latents, drop_motion_frames=False):
357
+ if drop_motion_frames or motion_latents[0].shape[1] == 0:
358
+ return [], []
359
+ self.lat_motion_frames = motion_latents[0].shape[1]
360
+ mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents]
361
+ batch_size = len(mot)
362
+
363
+ mot_remb = []
364
+ flattern_mot = []
365
+ for bs in range(batch_size):
366
+ height, width = mot[bs].shape[3], mot[bs].shape[4]
367
+ flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous()
368
+ motion_grid_sizes = [[
369
+ torch.tensor([-self.lat_motion_frames, 0,
370
+ 0]).unsqueeze(0).repeat(1, 1),
371
+ torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1),
372
+ torch.tensor([self.lat_motion_frames, height,
373
+ width]).unsqueeze(0).repeat(1, 1)
374
+ ]]
375
+ motion_rope_emb = rope_precompute(
376
+ flat_mot.detach().view(1, flat_mot.shape[1], self.num_heads,
377
+ self.dim // self.num_heads),
378
+ motion_grid_sizes,
379
+ self.freqs,
380
+ start=None)
381
+ mot_remb.append(motion_rope_emb)
382
+ flattern_mot.append(flat_mot)
383
+ return flattern_mot, mot_remb
384
+
385
+ def process_motion_frame_pack(self,
386
+ motion_latents,
387
+ drop_motion_frames=False,
388
+ add_last_motion=2):
389
+ flattern_mot, mot_remb = self.frame_packer(motion_latents,
390
+ add_last_motion)
391
+ if drop_motion_frames:
392
+ return [m[:, :0] for m in flattern_mot
393
+ ], [m[:, :0] for m in mot_remb]
394
+ else:
395
+ return flattern_mot, mot_remb
396
+
397
+ def process_motion_transformer_motioner(self,
398
+ motion_latents,
399
+ drop_motion_frames=False,
400
+ add_last_motion=True):
401
+ batch_size, height, width = len(
402
+ motion_latents), motion_latents[0].shape[2] // self.patch_size[
403
+ 1], motion_latents[0].shape[3] // self.patch_size[2]
404
+
405
+ freqs = self.freqs
406
+ device = self.patch_embedding.weight.device
407
+ if freqs.device != device:
408
+ freqs = freqs.to(device)
409
+ if self.trainable_token_pos_emb:
410
+ with amp.autocast(dtype=torch.float64):
411
+ token_freqs = self.token_freqs.to(torch.float64)
412
+ token_freqs = token_freqs / token_freqs.norm(
413
+ dim=-1, keepdim=True)
414
+ freqs = [freqs, torch.view_as_complex(token_freqs)]
415
+
416
+ if not drop_motion_frames and add_last_motion:
417
+ last_motion_latent = [u[:, -1:] for u in motion_latents]
418
+ last_mot = [
419
+ self.patch_embedding(m.unsqueeze(0)) for m in last_motion_latent
420
+ ]
421
+ last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot]
422
+ last_mot = torch.cat(last_mot)
423
+ gride_sizes = [[
424
+ torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
425
+ torch.tensor([0, height,
426
+ width]).unsqueeze(0).repeat(batch_size, 1),
427
+ torch.tensor([1, height,
428
+ width]).unsqueeze(0).repeat(batch_size, 1)
429
+ ]]
430
+ else:
431
+ last_mot = torch.zeros([batch_size, 0, self.dim],
432
+ device=motion_latents[0].device,
433
+ dtype=motion_latents[0].dtype)
434
+ gride_sizes = []
435
+
436
+ zip_motion = self.motioner(motion_latents)
437
+ zip_motion = self.zip_motion_out(zip_motion)
438
+ if drop_motion_frames:
439
+ zip_motion = zip_motion * 0.0
440
+ zip_motion_grid_sizes = [[
441
+ torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
442
+ torch.tensor([
443
+ 0, self.motioner.motion_side_len, self.motioner.motion_side_len
444
+ ]).unsqueeze(0).repeat(batch_size, 1),
445
+ torch.tensor(
446
+ [1 if not self.trainable_token_pos_emb else -1, height,
447
+ width]).unsqueeze(0).repeat(batch_size, 1),
448
+ ]]
449
+
450
+ mot = torch.cat([last_mot, zip_motion], dim=1)
451
+ gride_sizes = gride_sizes + zip_motion_grid_sizes
452
+
453
+ motion_rope_emb = rope_precompute(
454
+ mot.detach().view(batch_size, mot.shape[1], self.num_heads,
455
+ self.dim // self.num_heads),
456
+ gride_sizes,
457
+ freqs,
458
+ start=None)
459
+ return [m.unsqueeze(0) for m in mot
460
+ ], [r.unsqueeze(0) for r in motion_rope_emb]
461
+
462
+ def inject_motion(self,
463
+ x,
464
+ seq_lens,
465
+ rope_embs,
466
+ mask_input,
467
+ motion_latents,
468
+ drop_motion_frames=False,
469
+ add_last_motion=True):
470
+ # Inject the motion frame tokens into the hidden states
471
+ if self.enable_motioner:
472
+ mot, mot_remb = self.process_motion_transformer_motioner(
473
+ motion_latents,
474
+ drop_motion_frames=drop_motion_frames,
475
+ add_last_motion=add_last_motion)
476
+ elif self.enable_framepack:
477
+ mot, mot_remb = self.process_motion_frame_pack(
478
+ motion_latents,
479
+ drop_motion_frames=drop_motion_frames,
480
+ add_last_motion=add_last_motion)
481
+ else:
482
+ mot, mot_remb = self.process_motion(
483
+ motion_latents, drop_motion_frames=drop_motion_frames)
484
+
485
+ if len(mot) > 0:
486
+ x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)]
487
+ seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot],
488
+ dtype=torch.long)
489
+ rope_embs = [
490
+ torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)
491
+ ]
492
+ mask_input = [
493
+ torch.cat([
494
+ m, 2 * torch.ones([1, u.shape[1] - m.shape[1]],
495
+ device=m.device,
496
+ dtype=m.dtype)
497
+ ],
498
+ dim=1) for m, u in zip(mask_input, x)
499
+ ]
500
+ return x, seq_lens, rope_embs, mask_input
501
+
502
+ def after_transformer_block(self, block_idx, hidden_states):
503
+ if block_idx in self.audio_injector.injected_block_id.keys():
504
+ audio_attn_id = self.audio_injector.injected_block_id[block_idx]
505
+ audio_emb = self.merged_audio_emb # b f n c
506
+ num_frames = audio_emb.shape[1]
507
+
508
+ if self.sp_world_size > 1:
509
+ hidden_states = self.all_gather(hidden_states, dim=1)
510
+
511
+ input_hidden_states = hidden_states[:, :self.original_seq_len].clone()
512
+ input_hidden_states = rearrange(
513
+ input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
514
+
515
+ if self.enbale_adain and self.adain_mode == "attn_norm":
516
+ audio_emb_global = self.audio_emb_global
517
+ audio_emb_global = rearrange(audio_emb_global,
518
+ "b t n c -> (b t) n c")
519
+ adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](
520
+ input_hidden_states, temb=audio_emb_global[:, 0]
521
+ )
522
+ attn_hidden_states = adain_hidden_states
523
+ else:
524
+ attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](
525
+ input_hidden_states
526
+ )
527
+ audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
528
+ attn_audio_emb = audio_emb
529
+ context_lens = torch.ones(
530
+ attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device
531
+ ) * attn_audio_emb.shape[1]
532
+ residual_out = self.audio_injector.injector[audio_attn_id](
533
+ x=attn_hidden_states,
534
+ context=attn_audio_emb,
535
+ context_lens=context_lens)
536
+ residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
537
+ hidden_states[:, :self.original_seq_len] = hidden_states[:, :self.original_seq_len] + residual_out
538
+
539
+ if self.sp_world_size > 1:
540
+ hidden_states = torch.chunk(
541
+ hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
542
+
543
+ return hidden_states
544
+
545
+ def forward(
546
+ self,
547
+ x,
548
+ t,
549
+ context,
550
+ seq_len,
551
+ ref_latents,
552
+ motion_latents,
553
+ cond_states,
554
+ audio_input=None,
555
+ motion_frames=[17, 5],
556
+ add_last_motion=2,
557
+ drop_motion_frames=False,
558
+ cond_flag=True,
559
+ *extra_args,
560
+ **extra_kwargs
561
+ ):
562
+ """
563
+ x: A list of videos each with shape [C, T, H, W].
564
+ t: [B].
565
+ context: A list of text embeddings each with shape [L, C].
566
+ seq_len: A list of video token lengths; not needed by this model.
567
+ ref_latents: A list of reference images, one per video, each with shape [C, 1, H, W].
568
+ motion_latents: A list of motion frames for each video with shape [C, T_m, H, W].
569
+ cond_states: A list of condition frames (e.g. pose) each with shape [C, T, H, W].
570
+ audio_input: The input audio embedding [B, num_wav2vec_layer, C_a, T_a].
571
+ motion_frames: The number of motion frames and motion latent frames encoded by the VAE, e.g. [17, 5].
572
+ add_last_motion: For the motioner, if add_last_motion > 0, the most recent motion frame (i.e., the last frame) is added.
573
+ For frame packing, the behavior depends on the value of add_last_motion:
574
+ add_last_motion = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included.
575
+ add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included.
576
+ add_last_motion = 2: All motion-related latents are used.
577
+ drop_motion_frames: Bool, whether to drop the motion frames info.
578
+ """
579
+ device = self.patch_embedding.weight.device
580
+ dtype = x.dtype
581
+ add_last_motion = self.add_last_motion * add_last_motion
582
+
583
+ # Embeddings
584
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
585
+
586
+ # Audio process
587
+ audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
588
+ audio_emb_res = self.casual_audio_encoder(audio_input)
589
+ if self.enbale_adain:
590
+ audio_emb_global, audio_emb = audio_emb_res
591
+ self.audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
592
+ else:
593
+ audio_emb = audio_emb_res
594
+ self.merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
595
+
596
+ # Cond states
597
+ cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states]
598
+ x = [x_ + pose for x_, pose in zip(x, cond)]
599
+
600
+ grid_sizes = torch.stack(
601
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
602
+ x = [u.flatten(2).transpose(1, 2) for u in x]
603
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
604
+
605
+ original_grid_sizes = deepcopy(grid_sizes)
606
+ grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
607
+
608
+ # Ref latents
609
+ ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents]
610
+ batch_size = len(ref)
611
+ height, width = ref[0].shape[3], ref[0].shape[4]
612
+ ref = [r.flatten(2).transpose(1, 2) for r in ref] # r: 1 c f h w
613
+ x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)]
614
+
615
+ self.original_seq_len = seq_lens[0]
616
+ seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref], dtype=torch.long)
617
+ ref_grid_sizes = [
618
+ [
619
+ torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), # the start index
620
+ torch.tensor([31, height,width]).unsqueeze(0).repeat(batch_size, 1), # the end index
621
+ torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1),
622
+ ] # the range
623
+ ]
624
+ grid_sizes = grid_sizes + ref_grid_sizes
625
+
626
+ # Compute the rope embeddings for the input
627
+ x = torch.cat(x)
628
+ b, s, n, d = x.size(0), x.size(1), self.num_heads, self.dim // self.num_heads
629
+ self.pre_compute_freqs = rope_precompute(
630
+ x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None)
631
+ x = [u.unsqueeze(0) for u in x]
632
+ self.pre_compute_freqs = [u.unsqueeze(0) for u in self.pre_compute_freqs]
633
+
634
+ # Inject Motion latents.
635
+ # Initialize masks to indicate noisy latent, ref latent, and motion latent.
636
+ # However, at this point, only the first two (noisy and ref latents) are marked;
637
+ # the marking of motion latent will be implemented inside `inject_motion`.
638
+ mask_input = [
639
+ torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device)
640
+ for u in x
641
+ ]
642
+ for i in range(len(mask_input)):
643
+ mask_input[i][:, self.original_seq_len:] = 1
644
+
645
+ self.lat_motion_frames = motion_latents[0].shape[1]
646
+ x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion(
647
+ x,
648
+ seq_lens,
649
+ self.pre_compute_freqs,
650
+ mask_input,
651
+ motion_latents,
652
+ drop_motion_frames=drop_motion_frames,
653
+ add_last_motion=add_last_motion)
654
+ x = torch.cat(x, dim=0)
655
+ self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0)
656
+ mask_input = torch.cat(mask_input, dim=0)
657
+
658
+ # Apply trainable_cond_mask
659
+ x = x + self.trainable_cond_mask(mask_input).to(x.dtype)
660
+
661
+ seq_len = seq_lens.max()
662
+ if self.sp_world_size > 1:
663
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
664
+ assert seq_lens.max() <= seq_len
665
+ x = torch.cat([
666
+ torch.cat([u.unsqueeze(0), u.new_zeros(1, seq_len - u.size(0), u.size(1))],
667
+ dim=1) for u in x
668
+ ])
669
+
670
+ # Time embeddings
671
+ if self.zero_timestep:
672
+ t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)])
673
+ with amp.autocast(dtype=torch.float32):
674
+ e = self.time_embedding(
675
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
676
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
677
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
678
+ if self.zero_timestep:
679
+ e = e[:-1]
680
+ zero_e0 = e0[-1:]
681
+ e0 = e0[:-1]
682
+ token_len = x.shape[1]
683
+
684
+ e0 = torch.cat(
685
+ [
686
+ e0.unsqueeze(2),
687
+ zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)
688
+ ],
689
+ dim=2
690
+ )
691
+ e0 = [e0, self.original_seq_len]
692
+ else:
693
+ e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1)
694
+ e0 = [e0, 0]
695
+
696
+ # context
697
+ context_lens = None
698
+ context = self.text_embedding(
699
+ torch.stack([
700
+ torch.cat(
701
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
702
+ for u in context
703
+ ]))
704
+
705
+ if self.sp_world_size > 1:
706
+ # Sharded tensors for long context attn
707
+ x = torch.chunk(x, self.sp_world_size, dim=1)
708
+ sq_size = [u.shape[1] for u in x]
709
+ sq_start_size = sum(sq_size[:self.sp_world_rank])
710
+ x = x[self.sp_world_rank]
711
+ # Confirm the application range of the time embedding in e0[0] for each sequence:
712
+ # - For tokens before seg_id: apply e0[0][:, :, 0]
713
+ # - For tokens after seg_id: apply e0[0][:, :, 1]
714
+ sp_size = x.shape[1]
715
+ seg_idx = e0[1] - sq_start_size
716
+ e0[1] = seg_idx
717
+
718
+ self.pre_compute_freqs = torch.chunk(self.pre_compute_freqs, self.sp_world_size, dim=1)
719
+ self.pre_compute_freqs = self.pre_compute_freqs[self.sp_world_rank]
720
+
721
+ # TeaCache
722
+ if self.teacache is not None:
723
+ if cond_flag:
724
+ if t.dim() != 1:
725
+ modulated_inp = e0[0][:, -1, :]
726
+ else:
727
+ modulated_inp = e0[0]
728
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
729
+ if skip_flag:
730
+ self.should_calc = True
731
+ self.teacache.accumulated_rel_l1_distance = 0
732
+ else:
733
+ if cond_flag:
734
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
735
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
736
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
737
+ self.should_calc = False
738
+ else:
739
+ self.should_calc = True
740
+ self.teacache.accumulated_rel_l1_distance = 0
741
+ self.teacache.previous_modulated_input = modulated_inp
742
+ self.teacache.should_calc = self.should_calc
743
+ else:
744
+ self.should_calc = self.teacache.should_calc
745
+
746
+ # TeaCache
747
+ if self.teacache is not None:
748
+ if not self.should_calc:
749
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
750
+ x = x + previous_residual.to(x.device)[-x.size()[0]:,]
751
+ else:
752
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
753
+
754
+ for idx, block in enumerate(self.blocks):
755
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
756
+
757
+ def create_custom_forward(module):
758
+ def custom_forward(*inputs):
759
+ return module(*inputs)
760
+
761
+ return custom_forward
762
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
763
+ x = torch.utils.checkpoint.checkpoint(
764
+ create_custom_forward(block),
765
+ x,
766
+ e0,
767
+ seq_lens,
768
+ grid_sizes,
769
+ self.pre_compute_freqs,
770
+ context,
771
+ context_lens,
772
+ dtype,
773
+ t,
774
+ **ckpt_kwargs,
775
+ )
776
+ x = self.after_transformer_block(idx, x)
777
+ else:
778
+ # arguments
779
+ kwargs = dict(
780
+ e=e0,
781
+ seq_lens=seq_lens,
782
+ grid_sizes=grid_sizes,
783
+ freqs=self.pre_compute_freqs,
784
+ context=context,
785
+ context_lens=context_lens,
786
+ dtype=dtype,
787
+ t=t
788
+ )
789
+ x = block(x, **kwargs)
790
+ x = self.after_transformer_block(idx, x)
791
+
792
+ if cond_flag:
793
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
794
+ else:
795
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
796
+ else:
797
+ for idx, block in enumerate(self.blocks):
798
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
799
+
800
+ def create_custom_forward(module):
801
+ def custom_forward(*inputs):
802
+ return module(*inputs)
803
+
804
+ return custom_forward
805
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
806
+ x = torch.utils.checkpoint.checkpoint(
807
+ create_custom_forward(block),
808
+ x,
809
+ e0,
810
+ seq_lens,
811
+ grid_sizes,
812
+ self.pre_compute_freqs,
813
+ context,
814
+ context_lens,
815
+ dtype,
816
+ t,
817
+ **ckpt_kwargs,
818
+ )
819
+ x = self.after_transformer_block(idx, x)
820
+ else:
821
+ # arguments
822
+ kwargs = dict(
823
+ e=e0,
824
+ seq_lens=seq_lens,
825
+ grid_sizes=grid_sizes,
826
+ freqs=self.pre_compute_freqs,
827
+ context=context,
828
+ context_lens=context_lens,
829
+ dtype=dtype,
830
+ t=t
831
+ )
832
+ x = block(x, **kwargs)
833
+ x = self.after_transformer_block(idx, x)
834
+
835
+ # Context Parallel
836
+ if self.sp_world_size > 1:
837
+ x = self.all_gather(x.contiguous(), dim=1)
838
+
839
+ # Unpatchify
840
+ x = x[:, :self.original_seq_len]
841
+ # Head
842
+ x = self.head(x, e)
843
+ x = self.unpatchify(x, original_grid_sizes)
844
+ x = torch.stack(x)
845
+ if self.teacache is not None and cond_flag:
846
+ self.teacache.cnt += 1
847
+ if self.teacache.cnt == self.teacache.num_steps:
848
+ self.teacache.reset()
849
+ return x
850
+
851
+ def unpatchify(self, x, grid_sizes):
852
+ """
853
+ Reconstruct video tensors from patch embeddings.
854
+
855
+ Args:
856
+ x (List[Tensor]):
857
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
858
+ grid_sizes (Tensor):
859
+ Original spatial-temporal grid dimensions before patching,
860
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
861
+
862
+ Returns:
863
+ List[Tensor]:
864
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
865
+ """
866
+
867
+ c = self.out_dim
868
+ out = []
869
+ for u, v in zip(x, grid_sizes.tolist()):
870
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
871
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
872
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
873
+ out.append(u)
874
+ return out
875
+
876
+ def zero_init_weights(self):
877
+ with torch.no_grad():
878
+ self.trainable_cond_mask = zero_module(self.trainable_cond_mask)
879
+ if hasattr(self, "cond_encoder"):
880
+ self.cond_encoder = zero_module(self.cond_encoder)
881
+
882
+ for i in range(self.audio_injector.injector.__len__()):
883
+ self.audio_injector.injector[i].o = zero_module(
884
+ self.audio_injector.injector[i].o)
885
+ if self.enbale_adain:
886
+ self.audio_injector.injector_adain_layers[i].linear = \
887
+ zero_module(self.audio_injector.injector_adain_layers[i].linear)
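inject_motion above selects one of three motion encoders based on the constructor flags: the transformer motioner takes precedence, then frame packing, and otherwise the plain per-frame patch embedding. A minimal sketch of that dispatch, assuming only the two boolean flags (the helper name below is illustrative, not part of this commit):

def pick_motion_path(enable_motioner, enable_framepack):
    # Mirrors the checks in __init__ and inject_motion above.
    if enable_motioner and enable_framepack:
        raise ValueError("enable_motioner and enable_framepack are mutually exclusive")
    if enable_motioner:
        return "process_motion_transformer_motioner"  # compressed motion tokens (+ optional last frame)
    if enable_framepack:
        return "process_motion_frame_pack"            # FramePack buckets [1, 2, 16]
    return "process_motion"                           # plain per-frame patch embedding

print(pick_motion_path(True, False))  # -> process_motion_transformer_motioner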
videox_fun/models/wan_transformer3d_vace.py ADDED
@@ -0,0 +1,392 @@
1
+ # Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) Alibaba, Inc. and its affiliates.
4
+ from typing import Any, Dict
5
+
6
+ import os
7
+ import math
8
+ import torch
9
+ import torch.cuda.amp as amp
10
+ import torch.nn as nn
11
+ from diffusers.configuration_utils import register_to_config
12
+ from diffusers.utils import is_torch_version
13
+
14
+ from .wan_transformer3d import (WanAttentionBlock, WanTransformer3DModel,
15
+ sinusoidal_embedding_1d)
16
+
17
+
18
+ VIDEOX_OFFLOAD_VACE_LATENTS = os.environ.get("VIDEOX_OFFLOAD_VACE_LATENTS", False)
19
+
20
+ class VaceWanAttentionBlock(WanAttentionBlock):
21
+ def __init__(
22
+ self,
23
+ cross_attn_type,
24
+ dim,
25
+ ffn_dim,
26
+ num_heads,
27
+ window_size=(-1, -1),
28
+ qk_norm=True,
29
+ cross_attn_norm=False,
30
+ eps=1e-6,
31
+ block_id=0
32
+ ):
33
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
34
+ self.block_id = block_id
35
+ if block_id == 0:
36
+ self.before_proj = nn.Linear(self.dim, self.dim)
37
+ nn.init.zeros_(self.before_proj.weight)
38
+ nn.init.zeros_(self.before_proj.bias)
39
+ self.after_proj = nn.Linear(self.dim, self.dim)
40
+ nn.init.zeros_(self.after_proj.weight)
41
+ nn.init.zeros_(self.after_proj.bias)
42
+
43
+ def forward(self, c, x, **kwargs):
44
+ if self.block_id == 0:
45
+ c = self.before_proj(c) + x
46
+ all_c = []
47
+ else:
48
+ all_c = list(torch.unbind(c))
49
+ c = all_c.pop(-1)
50
+
51
+ if VIDEOX_OFFLOAD_VACE_LATENTS:
52
+ c = c.to(x.device)
53
+
54
+ c = super().forward(c, **kwargs)
55
+ c_skip = self.after_proj(c)
56
+
57
+ if VIDEOX_OFFLOAD_VACE_LATENTS:
58
+ c_skip = c_skip.to("cpu")
59
+ c = c.to("cpu")
60
+
61
+ all_c += [c_skip, c]
62
+ c = torch.stack(all_c)
63
+ return c
64
+
65
+
66
+ class BaseWanAttentionBlock(WanAttentionBlock):
67
+ def __init__(
68
+ self,
69
+ cross_attn_type,
70
+ dim,
71
+ ffn_dim,
72
+ num_heads,
73
+ window_size=(-1, -1),
74
+ qk_norm=True,
75
+ cross_attn_norm=False,
76
+ eps=1e-6,
77
+ block_id=None
78
+ ):
79
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
80
+ self.block_id = block_id
81
+
82
+ def forward(self, x, hints, context_scale=1.0, **kwargs):
83
+ x = super().forward(x, **kwargs)
84
+ if self.block_id is not None:
85
+ if VIDEOX_OFFLOAD_VACE_LATENTS:
86
+ x = x + hints[self.block_id].to(x.device) * context_scale
87
+ else:
88
+ x = x + hints[self.block_id] * context_scale
89
+ return x
90
+
91
+
92
+ class VaceWanTransformer3DModel(WanTransformer3DModel):
93
+ @register_to_config
94
+ def __init__(self,
95
+ vace_layers=None,
96
+ vace_in_dim=None,
97
+ model_type='t2v',
98
+ patch_size=(1, 2, 2),
99
+ text_len=512,
100
+ in_dim=16,
101
+ dim=2048,
102
+ ffn_dim=8192,
103
+ freq_dim=256,
104
+ text_dim=4096,
105
+ out_dim=16,
106
+ num_heads=16,
107
+ num_layers=32,
108
+ window_size=(-1, -1),
109
+ qk_norm=True,
110
+ cross_attn_norm=True,
111
+ eps=1e-6):
112
+ model_type = "t2v" # TODO: Hard-coded to 't2v' for both the preview and official versions.
113
+ super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
114
+ num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
115
+
116
+ self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
117
+ self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
118
+
119
+ assert 0 in self.vace_layers
120
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
121
+
122
+ # blocks
123
+ self.blocks = nn.ModuleList([
124
+ BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
125
+ self.cross_attn_norm, self.eps,
126
+ block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
127
+ for i in range(self.num_layers)
128
+ ])
129
+
130
+ # vace blocks
131
+ self.vace_blocks = nn.ModuleList([
132
+ VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
133
+ self.cross_attn_norm, self.eps, block_id=i)
134
+ for i in self.vace_layers
135
+ ])
136
+
137
+ # vace patch embeddings
138
+ self.vace_patch_embedding = nn.Conv3d(
139
+ self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
140
+ )
141
+
142
+ def forward_vace(
143
+ self,
144
+ x,
145
+ vace_context,
146
+ seq_len,
147
+ kwargs
148
+ ):
149
+ # embeddings
150
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
151
+ c = [u.flatten(2).transpose(1, 2) for u in c]
152
+ c = torch.cat([
153
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
154
+ dim=1) for u in c
155
+ ])
156
+ # Context Parallel
157
+ if self.sp_world_size > 1:
158
+ c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]
159
+
160
+ # arguments
161
+ new_kwargs = dict(x=x)
162
+ new_kwargs.update(kwargs)
163
+
164
+ for block in self.vace_blocks:
165
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
166
+ def create_custom_forward(module, **static_kwargs):
167
+ def custom_forward(*inputs):
168
+ return module(*inputs, **static_kwargs)
169
+ return custom_forward
170
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
171
+ c = torch.utils.checkpoint.checkpoint(
172
+ create_custom_forward(block, **new_kwargs),
173
+ c,
174
+ **ckpt_kwargs,
175
+ )
176
+ else:
177
+ c = block(c, **new_kwargs)
178
+ hints = torch.unbind(c)[:-1]
179
+ return hints
180
+
181
+ def forward(
182
+ self,
183
+ x,
184
+ t,
185
+ vace_context,
186
+ context,
187
+ seq_len,
188
+ vace_context_scale=1.0,
189
+ clip_fea=None,
190
+ y=None,
191
+ cond_flag=True
192
+ ):
193
+ r"""
194
+ Forward pass through the diffusion model
195
+
196
+ Args:
197
+ x (List[Tensor]):
198
+ List of input video tensors, each with shape [C_in, F, H, W]
199
+ t (Tensor):
200
+ Diffusion timesteps tensor of shape [B]
201
+ context (List[Tensor]):
202
+ List of text embeddings each with shape [L, C]
203
+ seq_len (`int`):
204
+ Maximum sequence length for positional encoding
205
+ clip_fea (Tensor, *optional*):
206
+ CLIP image features for image-to-video mode
207
+ y (List[Tensor], *optional*):
208
+ Conditional video inputs for image-to-video mode, same shape as x
209
+
210
+ Returns:
211
+ List[Tensor]:
212
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
213
+ """
214
+ # if self.model_type == 'i2v':
215
+ # assert clip_fea is not None and y is not None
216
+ # params
217
+ dtype = x.dtype
218
+ device = self.patch_embedding.weight.device
219
+ if self.freqs.device != device and torch.device(type="meta") != device:
220
+ self.freqs = self.freqs.to(device)
221
+
222
+ # if y is not None:
223
+ # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
224
+
225
+ # embeddings
226
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
227
+ grid_sizes = torch.stack(
228
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
229
+ x = [u.flatten(2).transpose(1, 2) for u in x]
230
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
231
+ if self.sp_world_size > 1:
232
+ seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
233
+ assert seq_lens.max() <= seq_len
234
+ x = torch.cat([
235
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
236
+ dim=1) for u in x
237
+ ])
238
+
239
+ # time embeddings
240
+ with amp.autocast(dtype=torch.float32):
241
+ e = self.time_embedding(
242
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
243
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
244
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
245
+
246
+ # context
247
+ context_lens = None
248
+ context = self.text_embedding(
249
+ torch.stack([
250
+ torch.cat(
251
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
252
+ for u in context
253
+ ]))
254
+
255
+ # Context Parallel
256
+ if self.sp_world_size > 1:
257
+ x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
258
+
259
+ # arguments
260
+ kwargs = dict(
261
+ e=e0,
262
+ seq_lens=seq_lens,
263
+ grid_sizes=grid_sizes,
264
+ freqs=self.freqs,
265
+ context=context,
266
+ context_lens=context_lens,
267
+ dtype=dtype,
268
+ t=t)
269
+ hints = self.forward_vace(x, vace_context, seq_len, kwargs)
270
+
271
+ kwargs['hints'] = hints
272
+ kwargs['context_scale'] = vace_context_scale
273
+
274
+ # TeaCache
275
+ if self.teacache is not None:
276
+ if cond_flag:
277
+ if t.dim() != 1:
278
+ modulated_inp = e0[:, -1, :]
279
+ else:
280
+ modulated_inp = e0
281
+ skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
282
+ if skip_flag:
283
+ self.should_calc = True
284
+ self.teacache.accumulated_rel_l1_distance = 0
285
+ else:
286
+ if cond_flag:
287
+ rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
288
+ self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
289
+ if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
290
+ self.should_calc = False
291
+ else:
292
+ self.should_calc = True
293
+ self.teacache.accumulated_rel_l1_distance = 0
294
+ self.teacache.previous_modulated_input = modulated_inp
295
+ self.teacache.should_calc = self.should_calc
296
+ else:
297
+ self.should_calc = self.teacache.should_calc
298
+
299
+ # TeaCache
300
+ if self.teacache is not None:
301
+ if not self.should_calc:
302
+ previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
303
+ x = x + previous_residual.to(x.device)[-x.size()[0]:,]
304
+ else:
305
+ ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
306
+
307
+ for block in self.blocks:
308
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
309
+ def create_custom_forward(module, **static_kwargs):
310
+ def custom_forward(*inputs):
311
+ return module(*inputs, **static_kwargs)
312
+ return custom_forward
313
+ extra_kwargs = {
314
+ 'e': e0,
315
+ 'seq_lens': seq_lens,
316
+ 'grid_sizes': grid_sizes,
317
+ 'freqs': self.freqs,
318
+ 'context': context,
319
+ 'context_lens': context_lens,
320
+ 'dtype': dtype,
321
+ 't': t,
322
+ }
323
+
324
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
325
+
326
+ x = torch.utils.checkpoint.checkpoint(
327
+ create_custom_forward(block, **extra_kwargs),
328
+ x,
329
+ hints,
330
+ vace_context_scale,
331
+ **ckpt_kwargs,
332
+ )
333
+ else:
334
+ x = block(x, **kwargs)
335
+
336
+ if cond_flag:
337
+ self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
338
+ else:
339
+ self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
340
+ else:
341
+ for block in self.blocks:
342
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
343
+ def create_custom_forward(module, **static_kwargs):
344
+ def custom_forward(*inputs):
345
+ return module(*inputs, **static_kwargs)
346
+ return custom_forward
347
+ extra_kwargs = {
348
+ 'e': e0,
349
+ 'seq_lens': seq_lens,
350
+ 'grid_sizes': grid_sizes,
351
+ 'freqs': self.freqs,
352
+ 'context': context,
353
+ 'context_lens': context_lens,
354
+ 'dtype': dtype,
355
+ 't': t,
356
+ }
357
+
358
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
359
+
360
+ x = torch.utils.checkpoint.checkpoint(
361
+ create_custom_forward(block, **extra_kwargs),
362
+ x,
363
+ hints,
364
+ vace_context_scale,
365
+ **ckpt_kwargs,
366
+ )
367
+ else:
368
+ x = block(x, **kwargs)
369
+
370
+ # head
371
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
372
+ def create_custom_forward(module):
373
+ def custom_forward(*inputs):
374
+ return module(*inputs)
375
+
376
+ return custom_forward
377
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
378
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
379
+ else:
380
+ x = self.head(x, e)
381
+
382
+ if self.sp_world_size > 1:
383
+ x = self.all_gather(x, dim=1)
384
+
385
+ # unpatchify
386
+ x = self.unpatchify(x, grid_sizes)
387
+ x = torch.stack(x)
388
+ if self.teacache is not None and cond_flag:
389
+ self.teacache.cnt += 1
390
+ if self.teacache.cnt == self.teacache.num_steps:
391
+ self.teacache.reset()
392
+ return x
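To summarize the control flow above: forward_vace runs the VACE blocks once to produce one hint per mapped layer, and every BaseWanAttentionBlock whose block_id is not None adds its hint, scaled by context_scale, after the base attention block. A toy sketch of that injection pattern, with assumed shapes rather than real model dimensions:

import torch

dim, seq = 8, 4
x = torch.zeros(1, seq, dim)                          # main-branch hidden states
hints = [torch.randn(1, seq, dim) for _ in range(2)]  # one hint per VACE block (cf. forward_vace)
context_scale = 1.0

def base_block(h):                                    # stand-in for WanAttentionBlock.forward
    return h + 1.0

for block_id in [0, None, 1, None]:                   # every other layer mapped, cf. vace_layers_mapping
    x = base_block(x)
    if block_id is not None:
        x = x + hints[block_id] * context_scale       # cf. BaseWanAttentionBlock.forward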
videox_fun/models/wan_vae.py ADDED
@@ -0,0 +1,706 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ from typing import Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
10
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
11
+ DiagonalGaussianDistribution)
12
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.utils.accelerate_utils import apply_forward_hook
15
+ from einops import rearrange
16
+
17
+
18
+ CACHE_T = 2
19
+
20
+
21
+ class CausalConv3d(nn.Conv3d):
22
+ """
23
+ Causal 3D convolution.
24
+ """
25
+
26
+ def __init__(self, *args, **kwargs):
27
+ super().__init__(*args, **kwargs)
28
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
29
+ self.padding[1], 2 * self.padding[0], 0)
30
+ self.padding = (0, 0, 0)
31
+
32
+ def forward(self, x, cache_x=None):
33
+ padding = list(self._padding)
34
+ if cache_x is not None and self._padding[4] > 0:
35
+ cache_x = cache_x.to(x.device)
36
+ x = torch.cat([cache_x, x], dim=2)
37
+ padding[4] -= cache_x.shape[2]
38
+ x = F.pad(x, padding)
39
+
40
+ return super().forward(x)
41
+
42
+
43
+ class RMS_norm(nn.Module):
44
+
45
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
46
+ super().__init__()
47
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
48
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
49
+
50
+ self.channel_first = channel_first
51
+ self.scale = dim**0.5
52
+ self.gamma = nn.Parameter(torch.ones(shape))
53
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
54
+
55
+ def forward(self, x):
56
+ return F.normalize(
57
+ x, dim=(1 if self.channel_first else
58
+ -1)) * self.scale * self.gamma + self.bias
59
+
60
+
61
+ class Upsample(nn.Upsample):
62
+
63
+ def forward(self, x):
64
+ """
65
+ Fix bfloat16 support for nearest neighbor interpolation.
66
+ """
67
+ return super().forward(x.float()).type_as(x)
68
+
69
+
70
+ class Resample(nn.Module):
71
+
72
+ def __init__(self, dim, mode):
73
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
74
+ 'downsample3d')
75
+ super().__init__()
76
+ self.dim = dim
77
+ self.mode = mode
78
+
79
+ # layers
80
+ if mode == 'upsample2d':
81
+ self.resample = nn.Sequential(
82
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
84
+ elif mode == 'upsample3d':
85
+ self.resample = nn.Sequential(
86
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
87
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
88
+ self.time_conv = CausalConv3d(
89
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
90
+
91
+ elif mode == 'downsample2d':
92
+ self.resample = nn.Sequential(
93
+ nn.ZeroPad2d((0, 1, 0, 1)),
94
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95
+ elif mode == 'downsample3d':
96
+ self.resample = nn.Sequential(
97
+ nn.ZeroPad2d((0, 1, 0, 1)),
98
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
99
+ self.time_conv = CausalConv3d(
100
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
101
+
102
+ else:
103
+ self.resample = nn.Identity()
104
+
105
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
106
+ b, c, t, h, w = x.size()
107
+ if self.mode == 'upsample3d':
108
+ if feat_cache is not None:
109
+ idx = feat_idx[0]
110
+ if feat_cache[idx] is None:
111
+ feat_cache[idx] = 'Rep'
112
+ feat_idx[0] += 1
113
+ else:
114
+
115
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
116
+ if cache_x.shape[2] < 2 and feat_cache[
117
+ idx] is not None and feat_cache[idx] != 'Rep':
118
+ # cache last frame of last two chunk
119
+ cache_x = torch.cat([
120
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
121
+ cache_x.device), cache_x
122
+ ],
123
+ dim=2)
124
+ if cache_x.shape[2] < 2 and feat_cache[
125
+ idx] is not None and feat_cache[idx] == 'Rep':
126
+ cache_x = torch.cat([
127
+ torch.zeros_like(cache_x).to(cache_x.device),
128
+ cache_x
129
+ ],
130
+ dim=2)
131
+ if feat_cache[idx] == 'Rep':
132
+ x = self.time_conv(x)
133
+ else:
134
+ x = self.time_conv(x, feat_cache[idx])
135
+ feat_cache[idx] = cache_x
136
+ feat_idx[0] += 1
137
+
138
+ x = x.reshape(b, 2, c, t, h, w)
139
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
140
+ 3)
141
+ x = x.reshape(b, c, t * 2, h, w)
142
+ t = x.shape[2]
143
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
144
+ x = self.resample(x)
145
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
146
+
147
+ if self.mode == 'downsample3d':
148
+ if feat_cache is not None:
149
+ idx = feat_idx[0]
150
+ if feat_cache[idx] is None:
151
+ feat_cache[idx] = x.clone()
152
+ feat_idx[0] += 1
153
+ else:
154
+
155
+ cache_x = x[:, :, -1:, :, :].clone()
156
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
157
+ # # cache last frame of last two chunk
158
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
159
+
160
+ x = self.time_conv(
161
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
162
+ feat_cache[idx] = cache_x
163
+ feat_idx[0] += 1
164
+ return x
165
+
166
+ def init_weight(self, conv):
167
+ conv_weight = conv.weight
168
+ nn.init.zeros_(conv_weight)
169
+ c1, c2, t, h, w = conv_weight.size()
170
+ one_matrix = torch.eye(c1, c2)
171
+ init_matrix = one_matrix
172
+ nn.init.zeros_(conv_weight)
173
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
174
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
175
+ conv.weight.data.copy_(conv_weight)
176
+ nn.init.zeros_(conv.bias.data)
177
+
178
+ def init_weight2(self, conv):
179
+ conv_weight = conv.weight.data
180
+ nn.init.zeros_(conv_weight)
181
+ c1, c2, t, h, w = conv_weight.size()
182
+ init_matrix = torch.eye(c1 // 2, c2)
183
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
184
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
185
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
186
+ conv.weight.data.copy_(conv_weight)
187
+ nn.init.zeros_(conv.bias.data)
188
+
189
+
190
+ class ResidualBlock(nn.Module):
191
+
192
+ def __init__(self, in_dim, out_dim, dropout=0.0):
193
+ super().__init__()
194
+ self.in_dim = in_dim
195
+ self.out_dim = out_dim
196
+
197
+ # layers
198
+ self.residual = nn.Sequential(
199
+ RMS_norm(in_dim, images=False), nn.SiLU(),
200
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
201
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
202
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
203
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
204
+ if in_dim != out_dim else nn.Identity()
205
+
206
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
207
+ h = self.shortcut(x)
208
+ for layer in self.residual:
209
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
210
+ idx = feat_idx[0]
211
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
212
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
213
+ # cache last frame of last two chunk
214
+ cache_x = torch.cat([
215
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
216
+ cache_x.device), cache_x
217
+ ],
218
+ dim=2)
219
+ x = layer(x, feat_cache[idx])
220
+ feat_cache[idx] = cache_x
221
+ feat_idx[0] += 1
222
+ else:
223
+ x = layer(x)
224
+ return x + h
225
+
226
+
227
+ class AttentionBlock(nn.Module):
228
+ """
229
+ Causal self-attention with a single head.
230
+ """
231
+
232
+ def __init__(self, dim):
233
+ super().__init__()
234
+ self.dim = dim
235
+
236
+ # layers
237
+ self.norm = RMS_norm(dim)
238
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
239
+ self.proj = nn.Conv2d(dim, dim, 1)
240
+
241
+ # zero out the last layer params
242
+ nn.init.zeros_(self.proj.weight)
243
+
244
+ def forward(self, x):
245
+ identity = x
246
+ b, c, t, h, w = x.size()
247
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
248
+ x = self.norm(x)
249
+ # compute query, key, value
250
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
251
+ -1).permute(0, 1, 3,
252
+ 2).contiguous().chunk(
253
+ 3, dim=-1)
254
+
255
+ # apply attention
256
+ x = F.scaled_dot_product_attention(
257
+ q,
258
+ k,
259
+ v,
260
+ )
261
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
262
+
263
+ # output
264
+ x = self.proj(x)
265
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
266
+ return x + identity
267
+
268
+
269
+ class Encoder3d(nn.Module):
270
+
271
+ def __init__(self,
272
+ dim=128,
273
+ z_dim=4,
274
+ dim_mult=[1, 2, 4, 4],
275
+ num_res_blocks=2,
276
+ attn_scales=[],
277
+ temperal_downsample=[True, True, False],
278
+ dropout=0.0):
279
+ super().__init__()
280
+ self.dim = dim
281
+ self.z_dim = z_dim
282
+ self.dim_mult = dim_mult
283
+ self.num_res_blocks = num_res_blocks
284
+ self.attn_scales = attn_scales
285
+ self.temperal_downsample = temperal_downsample
286
+
287
+ # dimensions
288
+ dims = [dim * u for u in [1] + dim_mult]
289
+ scale = 1.0
290
+
291
+ # init block
292
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
293
+
294
+ # downsample blocks
295
+ downsamples = []
296
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
297
+ # residual (+attention) blocks
298
+ for _ in range(num_res_blocks):
299
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
300
+ if scale in attn_scales:
301
+ downsamples.append(AttentionBlock(out_dim))
302
+ in_dim = out_dim
303
+
304
+ # downsample block
305
+ if i != len(dim_mult) - 1:
306
+ mode = 'downsample3d' if temperal_downsample[
307
+ i] else 'downsample2d'
308
+ downsamples.append(Resample(out_dim, mode=mode))
309
+ scale /= 2.0
310
+ self.downsamples = nn.Sequential(*downsamples)
311
+
312
+ # middle blocks
313
+ self.middle = nn.Sequential(
314
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
315
+ ResidualBlock(out_dim, out_dim, dropout))
316
+
317
+ # output blocks
318
+ self.head = nn.Sequential(
319
+ RMS_norm(out_dim, images=False), nn.SiLU(),
320
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
321
+
322
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
323
+ if feat_cache is not None:
324
+ idx = feat_idx[0]
325
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
326
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
327
+ # cache last frame of last two chunk
328
+ cache_x = torch.cat([
329
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
330
+ cache_x.device), cache_x
331
+ ],
332
+ dim=2)
333
+ x = self.conv1(x, feat_cache[idx])
334
+ feat_cache[idx] = cache_x
335
+ feat_idx[0] += 1
336
+ else:
337
+ x = self.conv1(x)
338
+
339
+ ## downsamples
340
+ for layer in self.downsamples:
341
+ if feat_cache is not None:
342
+ x = layer(x, feat_cache, feat_idx)
343
+ else:
344
+ x = layer(x)
345
+
346
+ ## middle
347
+ for layer in self.middle:
348
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
349
+ x = layer(x, feat_cache, feat_idx)
350
+ else:
351
+ x = layer(x)
352
+
353
+ ## head
354
+ for layer in self.head:
355
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
356
+ idx = feat_idx[0]
357
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
358
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
359
+ # cache last frame of last two chunk
360
+ cache_x = torch.cat([
361
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
362
+ cache_x.device), cache_x
363
+ ],
364
+ dim=2)
365
+ x = layer(x, feat_cache[idx])
366
+ feat_cache[idx] = cache_x
367
+ feat_idx[0] += 1
368
+ else:
369
+ x = layer(x)
370
+ return x
371
+
372
+
373
+ class Decoder3d(nn.Module):
374
+
375
+ def __init__(self,
376
+ dim=128,
377
+ z_dim=4,
378
+ dim_mult=[1, 2, 4, 4],
379
+ num_res_blocks=2,
380
+ attn_scales=[],
381
+ temperal_upsample=[False, True, True],
382
+ dropout=0.0):
383
+ super().__init__()
384
+ self.dim = dim
385
+ self.z_dim = z_dim
386
+ self.dim_mult = dim_mult
387
+ self.num_res_blocks = num_res_blocks
388
+ self.attn_scales = attn_scales
389
+ self.temperal_upsample = temperal_upsample
390
+
391
+ # dimensions
392
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
393
+ scale = 1.0 / 2**(len(dim_mult) - 2)
394
+
395
+ # init block
396
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
397
+
398
+ # middle blocks
399
+ self.middle = nn.Sequential(
400
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
401
+ ResidualBlock(dims[0], dims[0], dropout))
402
+
403
+ # upsample blocks
404
+ upsamples = []
405
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
406
+ # residual (+attention) blocks
407
+ if i == 1 or i == 2 or i == 3:
408
+ in_dim = in_dim // 2
409
+ for _ in range(num_res_blocks + 1):
410
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
411
+ if scale in attn_scales:
412
+ upsamples.append(AttentionBlock(out_dim))
413
+ in_dim = out_dim
414
+
415
+ # upsample block
416
+ if i != len(dim_mult) - 1:
417
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
418
+ upsamples.append(Resample(out_dim, mode=mode))
419
+ scale *= 2.0
420
+ self.upsamples = nn.Sequential(*upsamples)
421
+
422
+ # output blocks
423
+ self.head = nn.Sequential(
424
+ RMS_norm(out_dim, images=False), nn.SiLU(),
425
+ CausalConv3d(out_dim, 3, 3, padding=1))
426
+
427
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
428
+ ## conv1
429
+ if feat_cache is not None:
430
+ idx = feat_idx[0]
431
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
432
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
433
+ # cache last frame of last two chunk
434
+ cache_x = torch.cat([
435
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
436
+ cache_x.device), cache_x
437
+ ],
438
+ dim=2)
439
+ x = self.conv1(x, feat_cache[idx])
440
+ feat_cache[idx] = cache_x
441
+ feat_idx[0] += 1
442
+ else:
443
+ x = self.conv1(x)
444
+
445
+ ## middle
446
+ for layer in self.middle:
447
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
448
+ x = layer(x, feat_cache, feat_idx)
449
+ else:
450
+ x = layer(x)
451
+
452
+ ## upsamples
453
+ for layer in self.upsamples:
454
+ if feat_cache is not None:
455
+ x = layer(x, feat_cache, feat_idx)
456
+ else:
457
+ x = layer(x)
458
+
459
+ ## head
460
+ for layer in self.head:
461
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
462
+ idx = feat_idx[0]
463
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
464
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
465
+ # cache last frame of last two chunk
466
+ cache_x = torch.cat([
467
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
468
+ cache_x.device), cache_x
469
+ ],
470
+ dim=2)
471
+ x = layer(x, feat_cache[idx])
472
+ feat_cache[idx] = cache_x
473
+ feat_idx[0] += 1
474
+ else:
475
+ x = layer(x)
476
+ return x
477
+
478
+
479
+ def count_conv3d(model):
480
+ count = 0
481
+ for m in model.modules():
482
+ if isinstance(m, CausalConv3d):
483
+ count += 1
484
+ return count
485
+
486
+
487
+ class AutoencoderKLWan_(nn.Module):
488
+
489
+ def __init__(self,
490
+ dim=128,
491
+ z_dim=4,
492
+ dim_mult=[1, 2, 4, 4],
493
+ num_res_blocks=2,
494
+ attn_scales=[],
495
+ temperal_downsample=[True, True, False],
496
+ dropout=0.0):
497
+ super().__init__()
498
+ self.dim = dim
499
+ self.z_dim = z_dim
500
+ self.dim_mult = dim_mult
501
+ self.num_res_blocks = num_res_blocks
502
+ self.attn_scales = attn_scales
503
+ self.temperal_downsample = temperal_downsample
504
+ self.temperal_upsample = temperal_downsample[::-1]
505
+
506
+ # modules
507
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
508
+ attn_scales, self.temperal_downsample, dropout)
509
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
510
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
511
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
512
+ attn_scales, self.temperal_upsample, dropout)
513
+
514
+ def forward(self, x):
515
+ mu, log_var = self.encode(x)
516
+ z = self.reparameterize(mu, log_var)
517
+ x_recon = self.decode(z)
518
+ return x_recon, mu, log_var
519
+
520
+ def encode(self, x, scale):
521
+ self.clear_cache()
522
+ ## cache
523
+ t = x.shape[2]
524
+ iter_ = 1 + (t - 1) // 4
525
+ scale = [item.to(x.device, x.dtype) for item in scale]
526
+ ## Split the input x for encode along the time axis into chunks of 1, 4, 4, 4, ...
527
+ for i in range(iter_):
528
+ self._enc_conv_idx = [0]
529
+ if i == 0:
530
+ out = self.encoder(
531
+ x[:, :, :1, :, :],
532
+ feat_cache=self._enc_feat_map,
533
+ feat_idx=self._enc_conv_idx)
534
+ else:
535
+ out_ = self.encoder(
536
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
537
+ feat_cache=self._enc_feat_map,
538
+ feat_idx=self._enc_conv_idx)
539
+ out = torch.cat([out, out_], 2)
540
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
541
+ if isinstance(scale[0], torch.Tensor):
542
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
543
+ 1, self.z_dim, 1, 1, 1)
544
+ else:
545
+ mu = (mu - scale[0]) * scale[1]
546
+ x = torch.cat([mu, log_var], dim = 1)
547
+ self.clear_cache()
548
+ return x
549
+
550
+ def decode(self, z, scale):
551
+ self.clear_cache()
552
+ # z: [b,c,t,h,w]
553
+ scale = [item.to(z.device, z.dtype) for item in scale]
554
+ if isinstance(scale[0], torch.Tensor):
555
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
556
+ 1, self.z_dim, 1, 1, 1)
557
+ else:
558
+ z = z / scale[1] + scale[0]
559
+ iter_ = z.shape[2]
560
+ x = self.conv2(z)
561
+ for i in range(iter_):
562
+ self._conv_idx = [0]
563
+ if i == 0:
564
+ out = self.decoder(
565
+ x[:, :, i:i + 1, :, :],
566
+ feat_cache=self._feat_map,
567
+ feat_idx=self._conv_idx)
568
+ else:
569
+ out_ = self.decoder(
570
+ x[:, :, i:i + 1, :, :],
571
+ feat_cache=self._feat_map,
572
+ feat_idx=self._conv_idx)
573
+ out = torch.cat([out, out_], 2)
574
+ self.clear_cache()
575
+ return out
576
+
577
+ def reparameterize(self, mu, log_var):
578
+ std = torch.exp(0.5 * log_var)
579
+ eps = torch.randn_like(std)
580
+ return eps * std + mu
581
+
582
+ def sample(self, imgs, deterministic=False):
583
+ mu, log_var = self.encode(imgs)
584
+ if deterministic:
585
+ return mu
586
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
587
+ return mu + std * torch.randn_like(std)
588
+
589
+ def clear_cache(self):
590
+ self._conv_num = count_conv3d(self.decoder)
591
+ self._conv_idx = [0]
592
+ self._feat_map = [None] * self._conv_num
593
+ #cache encode
594
+ self._enc_conv_num = count_conv3d(self.encoder)
595
+ self._enc_conv_idx = [0]
596
+ self._enc_feat_map = [None] * self._enc_conv_num
597
+
598
+
599
+ def _video_vae(z_dim=None, **kwargs):
600
+ """
601
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
602
+ """
603
+ # params
604
+ cfg = dict(
605
+ dim=96,
606
+ z_dim=z_dim,
607
+ dim_mult=[1, 2, 4, 4],
608
+ num_res_blocks=2,
609
+ attn_scales=[],
610
+ temperal_downsample=[False, True, True],
611
+ dropout=0.0)
612
+ cfg.update(**kwargs)
613
+
614
+ # init model
615
+ model = AutoencoderKLWan_(**cfg)
616
+
617
+ return model
618
+
619
+
620
+ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
621
+
622
+ @register_to_config
623
+ def __init__(
624
+ self,
625
+ latent_channels=16,
626
+ temporal_compression_ratio=4,
627
+ spatial_compression_ratio=8
628
+ ):
629
+ super().__init__()
630
+ mean = [
631
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
632
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
633
+ ]
634
+ std = [
635
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
636
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
637
+ ]
638
+ self.mean = torch.tensor(mean, dtype=torch.float32)
639
+ self.std = torch.tensor(std, dtype=torch.float32)
640
+ self.scale = [self.mean, 1.0 / self.std]
641
+
642
+ # init model
643
+ self.model = _video_vae(
644
+ z_dim=latent_channels,
645
+ )
646
+
647
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
648
+ x = [
649
+ self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
650
+ for u in x
651
+ ]
652
+ x = torch.stack(x)
653
+ return x
654
+
655
+ @apply_forward_hook
656
+ def encode(
657
+ self, x: torch.Tensor, return_dict: bool = True
658
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
659
+ h = self._encode(x)
660
+
661
+ posterior = DiagonalGaussianDistribution(h)
662
+
663
+ if not return_dict:
664
+ return (posterior,)
665
+ return AutoencoderKLOutput(latent_dist=posterior)
666
+
667
+ def _decode(self, zs):
668
+ dec = [
669
+ self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
670
+ for u in zs
671
+ ]
672
+ dec = torch.stack(dec)
673
+
674
+ return DecoderOutput(sample=dec)
675
+
676
+ @apply_forward_hook
677
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
678
+ decoded = self._decode(z).sample
679
+
680
+ if not return_dict:
681
+ return (decoded,)
682
+ return DecoderOutput(sample=decoded)
683
+
684
+ @classmethod
685
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
686
+ def filter_kwargs(cls, kwargs):
687
+ import inspect
688
+ sig = inspect.signature(cls.__init__)
689
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
690
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
691
+ return filtered_kwargs
692
+
693
+ model = cls(**filter_kwargs(cls, additional_kwargs))
694
+ if pretrained_model_path.endswith(".safetensors"):
695
+ from safetensors.torch import load_file, safe_open
696
+ state_dict = load_file(pretrained_model_path)
697
+ else:
698
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
699
+ tmp_state_dict = {}
700
+ for key in state_dict:
701
+ tmp_state_dict["model." + key] = state_dict[key]
702
+ state_dict = tmp_state_dict
703
+ m, u = model.load_state_dict(state_dict, strict=False)
704
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
705
+ print(m, u)
706
+ return model
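
For reference, a minimal usage sketch of the AutoencoderKLWan wrapper added above. The checkpoint path and input resolution are placeholders; any Wan 2.1 VAE weights and any 1 + 4k-frame clip with height and width divisible by 8 behave the same way.

import torch

from videox_fun.models import AutoencoderKLWan

# Hypothetical path to the Wan 2.1 VAE weights (.pth or .safetensors).
vae = AutoencoderKLWan.from_pretrained("models/Wan2.1_VAE.pth").eval()

# Dummy clip: batch 1, RGB, 17 frames, 480x832, values in [-1, 1].
video = torch.zeros(1, 3, 17, 480, 832)

with torch.no_grad():
    posterior = vae.encode(video).latent_dist      # DiagonalGaussianDistribution
    latents = posterior.sample()                   # [1, 16, 5, 60, 104]
    recon = vae.decode(latents).sample             # [1, 3, 17, 480, 832]

Because encode and decode already apply the per-channel latent mean/std stored in self.scale, the sampled latents can be passed straight to decode without an extra scaling factor.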
videox_fun/models/wan_vae3_8.py ADDED
@@ -0,0 +1,1080 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+ from typing import Tuple, Union
4
+
5
+ import torch
6
+ import torch.cuda.amp as amp
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
12
+ DiagonalGaussianDistribution)
13
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+ from diffusers.utils.accelerate_utils import apply_forward_hook
16
+ from einops import rearrange
17
+
18
+
19
+ CACHE_T = 2
20
+
21
+
22
+ class CausalConv3d(nn.Conv3d):
23
+ """
24
+ Causal 3D convolution.
25
+ """
26
+
27
+ def __init__(self, *args, **kwargs):
28
+ super().__init__(*args, **kwargs)
29
+ self._padding = (
30
+ self.padding[2],
31
+ self.padding[2],
32
+ self.padding[1],
33
+ self.padding[1],
34
+ 2 * self.padding[0],
35
+ 0,
36
+ )
37
+ self.padding = (0, 0, 0)
38
+
39
+ def forward(self, x, cache_x=None):
40
+ padding = list(self._padding)
41
+ if cache_x is not None and self._padding[4] > 0:
42
+ cache_x = cache_x.to(x.device)
43
+ x = torch.cat([cache_x, x], dim=2)
44
+ padding[4] -= cache_x.shape[2]
45
+ x = F.pad(x, padding)
46
+
47
+ return super().forward(x)
48
+
49
+
50
+ class RMS_norm(nn.Module):
51
+
52
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
53
+ super().__init__()
54
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
55
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
56
+
57
+ self.channel_first = channel_first
58
+ self.scale = dim**0.5
59
+ self.gamma = nn.Parameter(torch.ones(shape))
60
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
61
+
62
+ def forward(self, x):
63
+ return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
64
+ self.scale * self.gamma + self.bias)
65
+
66
+
67
+ class Upsample(nn.Upsample):
68
+
69
+ def forward(self, x):
70
+ """
71
+ Fix bfloat16 support for nearest neighbor interpolation.
72
+ """
73
+ return super().forward(x.float()).type_as(x)
74
+
75
+
76
+ class Resample(nn.Module):
77
+
78
+ def __init__(self, dim, mode):
79
+ assert mode in (
80
+ "none",
81
+ "upsample2d",
82
+ "upsample3d",
83
+ "downsample2d",
84
+ "downsample3d",
85
+ )
86
+ super().__init__()
87
+ self.dim = dim
88
+ self.mode = mode
89
+
90
+ # layers
91
+ if mode == "upsample2d":
92
+ self.resample = nn.Sequential(
93
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
94
+ nn.Conv2d(dim, dim, 3, padding=1),
95
+ )
96
+ elif mode == "upsample3d":
97
+ self.resample = nn.Sequential(
98
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
99
+ nn.Conv2d(dim, dim, 3, padding=1),
100
+ # nn.Conv2d(dim, dim//2, 3, padding=1)
101
+ )
102
+ self.time_conv = CausalConv3d(
103
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
104
+ elif mode == "downsample2d":
105
+ self.resample = nn.Sequential(
106
+ nn.ZeroPad2d((0, 1, 0, 1)),
107
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
108
+ elif mode == "downsample3d":
109
+ self.resample = nn.Sequential(
110
+ nn.ZeroPad2d((0, 1, 0, 1)),
111
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
112
+ self.time_conv = CausalConv3d(
113
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
114
+ else:
115
+ self.resample = nn.Identity()
116
+
117
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
118
+ b, c, t, h, w = x.size()
119
+ if self.mode == "upsample3d":
120
+ if feat_cache is not None:
121
+ idx = feat_idx[0]
122
+ if feat_cache[idx] is None:
123
+ feat_cache[idx] = "Rep"
124
+ feat_idx[0] += 1
125
+ else:
126
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
127
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
128
+ feat_cache[idx] != "Rep"):
129
+ # cache the last frame of the previous chunk as well
130
+ cache_x = torch.cat(
131
+ [
132
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
133
+ cache_x.device),
134
+ cache_x,
135
+ ],
136
+ dim=2,
137
+ )
138
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
139
+ feat_cache[idx] == "Rep"):
140
+ cache_x = torch.cat(
141
+ [
142
+ torch.zeros_like(cache_x).to(cache_x.device),
143
+ cache_x
144
+ ],
145
+ dim=2,
146
+ )
147
+ if feat_cache[idx] == "Rep":
148
+ x = self.time_conv(x)
149
+ else:
150
+ x = self.time_conv(x, feat_cache[idx])
151
+ feat_cache[idx] = cache_x
152
+ feat_idx[0] += 1
153
+ x = x.reshape(b, 2, c, t, h, w)
154
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
155
+ 3)
156
+ x = x.reshape(b, c, t * 2, h, w)
157
+ t = x.shape[2]
158
+ x = rearrange(x, "b c t h w -> (b t) c h w")
159
+ x = self.resample(x)
160
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
161
+
162
+ if self.mode == "downsample3d":
163
+ if feat_cache is not None:
164
+ idx = feat_idx[0]
165
+ if feat_cache[idx] is None:
166
+ feat_cache[idx] = x.clone()
167
+ feat_idx[0] += 1
168
+ else:
169
+ cache_x = x[:, :, -1:, :, :].clone()
170
+ x = self.time_conv(
171
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
172
+ feat_cache[idx] = cache_x
173
+ feat_idx[0] += 1
174
+ return x
175
+
176
+ def init_weight(self, conv):
177
+ conv_weight = conv.weight.detach().clone()
178
+ nn.init.zeros_(conv_weight)
179
+ c1, c2, t, h, w = conv_weight.size()
180
+ one_matrix = torch.eye(c1, c2)
181
+ init_matrix = one_matrix
182
+ nn.init.zeros_(conv_weight)
183
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
184
+ conv.weight = nn.Parameter(conv_weight)
185
+ nn.init.zeros_(conv.bias.data)
186
+
187
+ def init_weight2(self, conv):
188
+ conv_weight = conv.weight.data.detach().clone()
189
+ nn.init.zeros_(conv_weight)
190
+ c1, c2, t, h, w = conv_weight.size()
191
+ init_matrix = torch.eye(c1 // 2, c2)
192
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
193
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
194
+ conv.weight = nn.Parameter(conv_weight)
195
+ nn.init.zeros_(conv.bias.data)
196
+
197
+
198
+ class ResidualBlock(nn.Module):
199
+
200
+ def __init__(self, in_dim, out_dim, dropout=0.0):
201
+ super().__init__()
202
+ self.in_dim = in_dim
203
+ self.out_dim = out_dim
204
+
205
+ # layers
206
+ self.residual = nn.Sequential(
207
+ RMS_norm(in_dim, images=False),
208
+ nn.SiLU(),
209
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
210
+ RMS_norm(out_dim, images=False),
211
+ nn.SiLU(),
212
+ nn.Dropout(dropout),
213
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
214
+ )
215
+ self.shortcut = (
216
+ CausalConv3d(in_dim, out_dim, 1)
217
+ if in_dim != out_dim else nn.Identity())
218
+
219
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
220
+ h = self.shortcut(x)
221
+ for layer in self.residual:
222
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
223
+ idx = feat_idx[0]
224
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
225
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
226
+ # cache the last frame of the previous chunk as well
227
+ cache_x = torch.cat(
228
+ [
229
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
230
+ cache_x.device),
231
+ cache_x,
232
+ ],
233
+ dim=2,
234
+ )
235
+ x = layer(x, feat_cache[idx])
236
+ feat_cache[idx] = cache_x
237
+ feat_idx[0] += 1
238
+ else:
239
+ x = layer(x)
240
+ return x + h
241
+
242
+
243
+ class AttentionBlock(nn.Module):
244
+ """
245
+ Causal self-attention with a single head.
246
+ """
247
+
248
+ def __init__(self, dim):
249
+ super().__init__()
250
+ self.dim = dim
251
+
252
+ # layers
253
+ self.norm = RMS_norm(dim)
254
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
255
+ self.proj = nn.Conv2d(dim, dim, 1)
256
+
257
+ # zero out the last layer params
258
+ nn.init.zeros_(self.proj.weight)
259
+
260
+ def forward(self, x):
261
+ identity = x
262
+ b, c, t, h, w = x.size()
263
+ x = rearrange(x, "b c t h w -> (b t) c h w")
264
+ x = self.norm(x)
265
+ # compute query, key, value
266
+ q, k, v = (
267
+ self.to_qkv(x).reshape(b * t, 1, c * 3,
268
+ -1).permute(0, 1, 3,
269
+ 2).contiguous().chunk(3, dim=-1))
270
+
271
+ # apply attention
272
+ x = F.scaled_dot_product_attention(
273
+ q,
274
+ k,
275
+ v,
276
+ )
277
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
278
+
279
+ # output
280
+ x = self.proj(x)
281
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
282
+ return x + identity
283
+
284
+
285
+ def patchify(x, patch_size):
286
+ if patch_size == 1:
287
+ return x
288
+ if x.dim() == 4:
289
+ x = rearrange(
290
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
291
+ elif x.dim() == 5:
292
+ x = rearrange(
293
+ x,
294
+ "b c f (h q) (w r) -> b (c r q) f h w",
295
+ q=patch_size,
296
+ r=patch_size,
297
+ )
298
+ else:
299
+ raise ValueError(f"Invalid input shape: {x.shape}")
300
+
301
+ return x
302
+
303
+
304
+ def unpatchify(x, patch_size):
305
+ if patch_size == 1:
306
+ return x
307
+
308
+ if x.dim() == 4:
309
+ x = rearrange(
310
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
311
+ elif x.dim() == 5:
312
+ x = rearrange(
313
+ x,
314
+ "b (c r q) f h w -> b c f (h q) (w r)",
315
+ q=patch_size,
316
+ r=patch_size,
317
+ )
318
+ return x
319
+
320
+
321
+ class AvgDown3D(nn.Module):
322
+
323
+ def __init__(
324
+ self,
325
+ in_channels,
326
+ out_channels,
327
+ factor_t,
328
+ factor_s=1,
329
+ ):
330
+ super().__init__()
331
+ self.in_channels = in_channels
332
+ self.out_channels = out_channels
333
+ self.factor_t = factor_t
334
+ self.factor_s = factor_s
335
+ self.factor = self.factor_t * self.factor_s * self.factor_s
336
+
337
+ assert in_channels * self.factor % out_channels == 0
338
+ self.group_size = in_channels * self.factor // out_channels
339
+
340
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
341
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
342
+ pad = (0, 0, 0, 0, pad_t, 0)
343
+ x = F.pad(x, pad)
344
+ B, C, T, H, W = x.shape
345
+ x = x.view(
346
+ B,
347
+ C,
348
+ T // self.factor_t,
349
+ self.factor_t,
350
+ H // self.factor_s,
351
+ self.factor_s,
352
+ W // self.factor_s,
353
+ self.factor_s,
354
+ )
355
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
356
+ x = x.view(
357
+ B,
358
+ C * self.factor,
359
+ T // self.factor_t,
360
+ H // self.factor_s,
361
+ W // self.factor_s,
362
+ )
363
+ x = x.view(
364
+ B,
365
+ self.out_channels,
366
+ self.group_size,
367
+ T // self.factor_t,
368
+ H // self.factor_s,
369
+ W // self.factor_s,
370
+ )
371
+ x = x.mean(dim=2)
372
+ return x
373
+
374
+
375
+ class DupUp3D(nn.Module):
376
+
377
+ def __init__(
378
+ self,
379
+ in_channels: int,
380
+ out_channels: int,
381
+ factor_t,
382
+ factor_s=1,
383
+ ):
384
+ super().__init__()
385
+ self.in_channels = in_channels
386
+ self.out_channels = out_channels
387
+
388
+ self.factor_t = factor_t
389
+ self.factor_s = factor_s
390
+ self.factor = self.factor_t * self.factor_s * self.factor_s
391
+
392
+ assert out_channels * self.factor % in_channels == 0
393
+ self.repeats = out_channels * self.factor // in_channels
394
+
395
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
396
+ x = x.repeat_interleave(self.repeats, dim=1)
397
+ x = x.view(
398
+ x.size(0),
399
+ self.out_channels,
400
+ self.factor_t,
401
+ self.factor_s,
402
+ self.factor_s,
403
+ x.size(2),
404
+ x.size(3),
405
+ x.size(4),
406
+ )
407
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
408
+ x = x.view(
409
+ x.size(0),
410
+ self.out_channels,
411
+ x.size(2) * self.factor_t,
412
+ x.size(4) * self.factor_s,
413
+ x.size(6) * self.factor_s,
414
+ )
415
+ if first_chunk:
416
+ x = x[:, :, self.factor_t - 1:, :, :]
417
+ return x
418
+
419
+
420
+ class Down_ResidualBlock(nn.Module):
421
+
422
+ def __init__(self,
423
+ in_dim,
424
+ out_dim,
425
+ dropout,
426
+ mult,
427
+ temperal_downsample=False,
428
+ down_flag=False):
429
+ super().__init__()
430
+
431
+ # Shortcut path with downsample
432
+ self.avg_shortcut = AvgDown3D(
433
+ in_dim,
434
+ out_dim,
435
+ factor_t=2 if temperal_downsample else 1,
436
+ factor_s=2 if down_flag else 1,
437
+ )
438
+
439
+ # Main path with residual blocks and downsample
440
+ downsamples = []
441
+ for _ in range(mult):
442
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
443
+ in_dim = out_dim
444
+
445
+ # Add the final downsample block
446
+ if down_flag:
447
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
448
+ downsamples.append(Resample(out_dim, mode=mode))
449
+
450
+ self.downsamples = nn.Sequential(*downsamples)
451
+
452
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
453
+ x_copy = x.clone()
454
+ for module in self.downsamples:
455
+ x = module(x, feat_cache, feat_idx)
456
+
457
+ return x + self.avg_shortcut(x_copy)
458
+
459
+
460
+ class Up_ResidualBlock(nn.Module):
461
+
462
+ def __init__(self,
463
+ in_dim,
464
+ out_dim,
465
+ dropout,
466
+ mult,
467
+ temperal_upsample=False,
468
+ up_flag=False):
469
+ super().__init__()
470
+ # Shortcut path with upsample
471
+ if up_flag:
472
+ self.avg_shortcut = DupUp3D(
473
+ in_dim,
474
+ out_dim,
475
+ factor_t=2 if temperal_upsample else 1,
476
+ factor_s=2 if up_flag else 1,
477
+ )
478
+ else:
479
+ self.avg_shortcut = None
480
+
481
+ # Main path with residual blocks and upsample
482
+ upsamples = []
483
+ for _ in range(mult):
484
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
485
+ in_dim = out_dim
486
+
487
+ # Add the final upsample block
488
+ if up_flag:
489
+ mode = "upsample3d" if temperal_upsample else "upsample2d"
490
+ upsamples.append(Resample(out_dim, mode=mode))
491
+
492
+ self.upsamples = nn.Sequential(*upsamples)
493
+
494
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
495
+ x_main = x.clone()
496
+ for module in self.upsamples:
497
+ x_main = module(x_main, feat_cache, feat_idx)
498
+ if self.avg_shortcut is not None:
499
+ x_shortcut = self.avg_shortcut(x, first_chunk)
500
+ return x_main + x_shortcut
501
+ else:
502
+ return x_main
503
+
504
+
505
+ class Encoder3d(nn.Module):
506
+
507
+ def __init__(
508
+ self,
509
+ dim=128,
510
+ z_dim=4,
511
+ dim_mult=[1, 2, 4, 4],
512
+ num_res_blocks=2,
513
+ attn_scales=[],
514
+ temperal_downsample=[True, True, False],
515
+ dropout=0.0,
516
+ ):
517
+ super().__init__()
518
+ self.dim = dim
519
+ self.z_dim = z_dim
520
+ self.dim_mult = dim_mult
521
+ self.num_res_blocks = num_res_blocks
522
+ self.attn_scales = attn_scales
523
+ self.temperal_downsample = temperal_downsample
524
+
525
+ # dimensions
526
+ dims = [dim * u for u in [1] + dim_mult]
527
+ scale = 1.0
528
+
529
+ # init block
530
+ self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
531
+
532
+ # downsample blocks
533
+ downsamples = []
534
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
535
+ t_down_flag = (
536
+ temperal_downsample[i]
537
+ if i < len(temperal_downsample) else False)
538
+ downsamples.append(
539
+ Down_ResidualBlock(
540
+ in_dim=in_dim,
541
+ out_dim=out_dim,
542
+ dropout=dropout,
543
+ mult=num_res_blocks,
544
+ temperal_downsample=t_down_flag,
545
+ down_flag=i != len(dim_mult) - 1,
546
+ ))
547
+ scale /= 2.0
548
+ self.downsamples = nn.Sequential(*downsamples)
549
+
550
+ # middle blocks
551
+ self.middle = nn.Sequential(
552
+ ResidualBlock(out_dim, out_dim, dropout),
553
+ AttentionBlock(out_dim),
554
+ ResidualBlock(out_dim, out_dim, dropout),
555
+ )
556
+
557
+ # output blocks
558
+ self.head = nn.Sequential(
559
+ RMS_norm(out_dim, images=False),
560
+ nn.SiLU(),
561
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
562
+ )
563
+
564
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
565
+
566
+ if feat_cache is not None:
567
+ idx = feat_idx[0]
568
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
569
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
570
+ cache_x = torch.cat(
571
+ [
572
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
573
+ cache_x.device),
574
+ cache_x,
575
+ ],
576
+ dim=2,
577
+ )
578
+ x = self.conv1(x, feat_cache[idx])
579
+ feat_cache[idx] = cache_x
580
+ feat_idx[0] += 1
581
+ else:
582
+ x = self.conv1(x)
583
+
584
+ ## downsamples
585
+ for layer in self.downsamples:
586
+ if feat_cache is not None:
587
+ x = layer(x, feat_cache, feat_idx)
588
+ else:
589
+ x = layer(x)
590
+
591
+ ## middle
592
+ for layer in self.middle:
593
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
594
+ x = layer(x, feat_cache, feat_idx)
595
+ else:
596
+ x = layer(x)
597
+
598
+ ## head
599
+ for layer in self.head:
600
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
601
+ idx = feat_idx[0]
602
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
603
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
604
+ cache_x = torch.cat(
605
+ [
606
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
607
+ cache_x.device),
608
+ cache_x,
609
+ ],
610
+ dim=2,
611
+ )
612
+ x = layer(x, feat_cache[idx])
613
+ feat_cache[idx] = cache_x
614
+ feat_idx[0] += 1
615
+ else:
616
+ x = layer(x)
617
+
618
+ return x
619
+
620
+
621
+ class Decoder3d(nn.Module):
622
+
623
+ def __init__(
624
+ self,
625
+ dim=128,
626
+ z_dim=4,
627
+ dim_mult=[1, 2, 4, 4],
628
+ num_res_blocks=2,
629
+ attn_scales=[],
630
+ temperal_upsample=[False, True, True],
631
+ dropout=0.0,
632
+ ):
633
+ super().__init__()
634
+ self.dim = dim
635
+ self.z_dim = z_dim
636
+ self.dim_mult = dim_mult
637
+ self.num_res_blocks = num_res_blocks
638
+ self.attn_scales = attn_scales
639
+ self.temperal_upsample = temperal_upsample
640
+
641
+ # dimensions
642
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
643
+ scale = 1.0 / 2**(len(dim_mult) - 2)
644
+ # init block
645
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
646
+
647
+ # middle blocks
648
+ self.middle = nn.Sequential(
649
+ ResidualBlock(dims[0], dims[0], dropout),
650
+ AttentionBlock(dims[0]),
651
+ ResidualBlock(dims[0], dims[0], dropout),
652
+ )
653
+
654
+ # upsample blocks
655
+ upsamples = []
656
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
657
+ t_up_flag = temperal_upsample[i] if i < len(
658
+ temperal_upsample) else False
659
+ upsamples.append(
660
+ Up_ResidualBlock(
661
+ in_dim=in_dim,
662
+ out_dim=out_dim,
663
+ dropout=dropout,
664
+ mult=num_res_blocks + 1,
665
+ temperal_upsample=t_up_flag,
666
+ up_flag=i != len(dim_mult) - 1,
667
+ ))
668
+ self.upsamples = nn.Sequential(*upsamples)
669
+
670
+ # output blocks
671
+ self.head = nn.Sequential(
672
+ RMS_norm(out_dim, images=False),
673
+ nn.SiLU(),
674
+ CausalConv3d(out_dim, 12, 3, padding=1),
675
+ )
676
+
677
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
678
+ if feat_cache is not None:
679
+ idx = feat_idx[0]
680
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
681
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
682
+ cache_x = torch.cat(
683
+ [
684
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
685
+ cache_x.device),
686
+ cache_x,
687
+ ],
688
+ dim=2,
689
+ )
690
+ x = self.conv1(x, feat_cache[idx])
691
+ feat_cache[idx] = cache_x
692
+ feat_idx[0] += 1
693
+ else:
694
+ x = self.conv1(x)
695
+
696
+ for layer in self.middle:
697
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
698
+ x = layer(x, feat_cache, feat_idx)
699
+ else:
700
+ x = layer(x)
701
+
702
+ ## upsamples
703
+ for layer in self.upsamples:
704
+ if feat_cache is not None:
705
+ x = layer(x, feat_cache, feat_idx, first_chunk)
706
+ else:
707
+ x = layer(x)
708
+
709
+ ## head
710
+ for layer in self.head:
711
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
712
+ idx = feat_idx[0]
713
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
714
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
715
+ cache_x = torch.cat(
716
+ [
717
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
718
+ cache_x.device),
719
+ cache_x,
720
+ ],
721
+ dim=2,
722
+ )
723
+ x = layer(x, feat_cache[idx])
724
+ feat_cache[idx] = cache_x
725
+ feat_idx[0] += 1
726
+ else:
727
+ x = layer(x)
728
+ return x
729
+
730
+
731
+ def count_conv3d(model):
732
+ count = 0
733
+ for m in model.modules():
734
+ if isinstance(m, CausalConv3d):
735
+ count += 1
736
+ return count
737
+
738
+
739
+ class AutoencoderKLWan2_2_(nn.Module):
740
+
741
+ def __init__(
742
+ self,
743
+ dim=160,
744
+ dec_dim=256,
745
+ z_dim=16,
746
+ dim_mult=[1, 2, 4, 4],
747
+ num_res_blocks=2,
748
+ attn_scales=[],
749
+ temperal_downsample=[True, True, False],
750
+ dropout=0.0,
751
+ ):
752
+ super().__init__()
753
+ self.dim = dim
754
+ self.z_dim = z_dim
755
+ self.dim_mult = dim_mult
756
+ self.num_res_blocks = num_res_blocks
757
+ self.attn_scales = attn_scales
758
+ self.temperal_downsample = temperal_downsample
759
+ self.temperal_upsample = temperal_downsample[::-1]
760
+
761
+ # modules
762
+ self.encoder = Encoder3d(
763
+ dim,
764
+ z_dim * 2,
765
+ dim_mult,
766
+ num_res_blocks,
767
+ attn_scales,
768
+ self.temperal_downsample,
769
+ dropout,
770
+ )
771
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
772
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
773
+ self.decoder = Decoder3d(
774
+ dec_dim,
775
+ z_dim,
776
+ dim_mult,
777
+ num_res_blocks,
778
+ attn_scales,
779
+ self.temperal_upsample,
780
+ dropout,
781
+ )
782
+
783
+ def forward(self, x, scale=[0, 1]):
784
+ mu = self.encode(x, scale)
785
+ x_recon = self.decode(mu, scale)
786
+ return x_recon, mu
787
+
788
+ def encode(self, x, scale):
789
+ self.clear_cache()
790
+ # x: [b, c, t, h, w]
791
+ scale = [item.to(x.device, x.dtype) for item in scale]
792
+ x = patchify(x, patch_size=2)
793
+ t = x.shape[2]
794
+ iter_ = 1 + (t - 1) // 4
795
+ for i in range(iter_):
796
+ self._enc_conv_idx = [0]
797
+ if i == 0:
798
+ out = self.encoder(
799
+ x[:, :, :1, :, :],
800
+ feat_cache=self._enc_feat_map,
801
+ feat_idx=self._enc_conv_idx,
802
+ )
803
+ else:
804
+ out_ = self.encoder(
805
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
806
+ feat_cache=self._enc_feat_map,
807
+ feat_idx=self._enc_conv_idx,
808
+ )
809
+ out = torch.cat([out, out_], 2)
810
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
811
+ if isinstance(scale[0], torch.Tensor):
812
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
813
+ 1, self.z_dim, 1, 1, 1)
814
+ else:
815
+ mu = (mu - scale[0]) * scale[1]
816
+ x = torch.cat([mu, log_var], dim=1)
817
+ self.clear_cache()
818
+ return x
819
+
820
+ def decode(self, z, scale):
821
+ self.clear_cache()
822
+ # z: [b,c,t,h,w]
823
+ scale = [item.to(z.device, z.dtype) for item in scale]
824
+ if isinstance(scale[0], torch.Tensor):
825
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
826
+ 1, self.z_dim, 1, 1, 1)
827
+ else:
828
+ z = z / scale[1] + scale[0]
829
+ iter_ = z.shape[2]
830
+ x = self.conv2(z)
831
+ for i in range(iter_):
832
+ self._conv_idx = [0]
833
+ if i == 0:
834
+ out = self.decoder(
835
+ x[:, :, i:i + 1, :, :],
836
+ feat_cache=self._feat_map,
837
+ feat_idx=self._conv_idx,
838
+ first_chunk=True,
839
+ )
840
+ else:
841
+ out_ = self.decoder(
842
+ x[:, :, i:i + 1, :, :],
843
+ feat_cache=self._feat_map,
844
+ feat_idx=self._conv_idx,
845
+ )
846
+ out = torch.cat([out, out_], 2)
847
+ out = unpatchify(out, patch_size=2)
848
+ self.clear_cache()
849
+ return out
850
+
851
+ def reparameterize(self, mu, log_var):
852
+ std = torch.exp(0.5 * log_var)
853
+ eps = torch.randn_like(std)
854
+ return eps * std + mu
855
+
856
+ def sample(self, imgs, deterministic=False):
857
+ mu, log_var = self.encode(imgs)
858
+ if deterministic:
859
+ return mu
860
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
861
+ return mu + std * torch.randn_like(std)
862
+
863
+ def clear_cache(self):
864
+ self._conv_num = count_conv3d(self.decoder)
865
+ self._conv_idx = [0]
866
+ self._feat_map = [None] * self._conv_num
867
+ # cache encoder features
868
+ self._enc_conv_num = count_conv3d(self.encoder)
869
+ self._enc_conv_idx = [0]
870
+ self._enc_feat_map = [None] * self._enc_conv_num
871
+
872
+
873
+ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
874
+ # params
875
+ cfg = dict(
876
+ dim=dim,
877
+ z_dim=z_dim,
878
+ dim_mult=[1, 2, 4, 4],
879
+ num_res_blocks=2,
880
+ attn_scales=[],
881
+ temperal_downsample=[True, True, True],
882
+ dropout=0.0,
883
+ )
884
+ cfg.update(**kwargs)
885
+
886
+ # init model
887
+ model = AutoencoderKLWan2_2_(**cfg)
888
+
889
+ return model
890
+
891
+
892
+ class AutoencoderKLWan3_8(ModelMixin, ConfigMixin, FromOriginalModelMixin):
893
+
894
+ @register_to_config
895
+ def __init__(
896
+ self,
897
+ latent_channels=48,
898
+ c_dim=160,
899
+ vae_pth=None,
900
+ dim_mult=[1, 2, 4, 4],
901
+ temperal_downsample=[False, True, True],
902
+ temporal_compression_ratio=4,
903
+ spatial_compression_ratio=8
904
+ ):
905
+ super().__init__()
906
+ mean = torch.tensor(
907
+ [
908
+ -0.2289,
909
+ -0.0052,
910
+ -0.1323,
911
+ -0.2339,
912
+ -0.2799,
913
+ 0.0174,
914
+ 0.1838,
915
+ 0.1557,
916
+ -0.1382,
917
+ 0.0542,
918
+ 0.2813,
919
+ 0.0891,
920
+ 0.1570,
921
+ -0.0098,
922
+ 0.0375,
923
+ -0.1825,
924
+ -0.2246,
925
+ -0.1207,
926
+ -0.0698,
927
+ 0.5109,
928
+ 0.2665,
929
+ -0.2108,
930
+ -0.2158,
931
+ 0.2502,
932
+ -0.2055,
933
+ -0.0322,
934
+ 0.1109,
935
+ 0.1567,
936
+ -0.0729,
937
+ 0.0899,
938
+ -0.2799,
939
+ -0.1230,
940
+ -0.0313,
941
+ -0.1649,
942
+ 0.0117,
943
+ 0.0723,
944
+ -0.2839,
945
+ -0.2083,
946
+ -0.0520,
947
+ 0.3748,
948
+ 0.0152,
949
+ 0.1957,
950
+ 0.1433,
951
+ -0.2944,
952
+ 0.3573,
953
+ -0.0548,
954
+ -0.1681,
955
+ -0.0667,
956
+ ], dtype=torch.float32
957
+ )
958
+ std = torch.tensor(
959
+ [
960
+ 0.4765,
961
+ 1.0364,
962
+ 0.4514,
963
+ 1.1677,
964
+ 0.5313,
965
+ 0.4990,
966
+ 0.4818,
967
+ 0.5013,
968
+ 0.8158,
969
+ 1.0344,
970
+ 0.5894,
971
+ 1.0901,
972
+ 0.6885,
973
+ 0.6165,
974
+ 0.8454,
975
+ 0.4978,
976
+ 0.5759,
977
+ 0.3523,
978
+ 0.7135,
979
+ 0.6804,
980
+ 0.5833,
981
+ 1.4146,
982
+ 0.8986,
983
+ 0.5659,
984
+ 0.7069,
985
+ 0.5338,
986
+ 0.4889,
987
+ 0.4917,
988
+ 0.4069,
989
+ 0.4999,
990
+ 0.6866,
991
+ 0.4093,
992
+ 0.5709,
993
+ 0.6065,
994
+ 0.6415,
995
+ 0.4944,
996
+ 0.5726,
997
+ 1.2042,
998
+ 0.5458,
999
+ 1.6887,
1000
+ 0.3971,
1001
+ 1.0600,
1002
+ 0.3943,
1003
+ 0.5537,
1004
+ 0.5444,
1005
+ 0.4089,
1006
+ 0.7468,
1007
+ 0.7744,
1008
+ ], dtype=torch.float32
1009
+ )
1010
+ self.scale = [mean, 1.0 / std]
1011
+
1012
+ # init model
1013
+ self.model = _video_vae(
1014
+ pretrained_path=vae_pth,
1015
+ z_dim=latent_channels,
1016
+ dim=c_dim,
1017
+ dim_mult=dim_mult,
1018
+ temperal_downsample=temperal_downsample,
1019
+ ).eval().requires_grad_(False)
1020
+
1021
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
1022
+ x = [
1023
+ self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
1024
+ for u in x
1025
+ ]
1026
+ x = torch.stack(x)
1027
+ return x
1028
+
1029
+ @apply_forward_hook
1030
+ def encode(
1031
+ self, x: torch.Tensor, return_dict: bool = True
1032
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1033
+ h = self._encode(x)
1034
+
1035
+ posterior = DiagonalGaussianDistribution(h)
1036
+
1037
+ if not return_dict:
1038
+ return (posterior,)
1039
+ return AutoencoderKLOutput(latent_dist=posterior)
1040
+
1041
+ def _decode(self, zs):
1042
+ dec = [
1043
+ self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
1044
+ for u in zs
1045
+ ]
1046
+ dec = torch.stack(dec)
1047
+
1048
+ return DecoderOutput(sample=dec)
1049
+
1050
+ @apply_forward_hook
1051
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1052
+ decoded = self._decode(z).sample
1053
+
1054
+ if not return_dict:
1055
+ return (decoded,)
1056
+ return DecoderOutput(sample=decoded)
1057
+
1058
+ @classmethod
1059
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
1060
+ def filter_kwargs(cls, kwargs):
1061
+ import inspect
1062
+ sig = inspect.signature(cls.__init__)
1063
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
1064
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
1065
+ return filtered_kwargs
1066
+
1067
+ model = cls(**filter_kwargs(cls, additional_kwargs))
1068
+ if pretrained_model_path.endswith(".safetensors"):
1069
+ from safetensors.torch import load_file, safe_open
1070
+ state_dict = load_file(pretrained_model_path)
1071
+ else:
1072
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
1073
+ tmp_state_dict = {}
1074
+ for key in state_dict:
1075
+ tmp_state_dict["model." + key] = state_dict[key]
1076
+ state_dict = tmp_state_dict
1077
+ m, u = model.load_state_dict(state_dict, strict=False)
1078
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1079
+ print(m, u)
1080
+ return model
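
Analogously, a minimal sketch for the Wan 2.2 VAE wrapper above: the 2x patchify in front of the encoder, stacked on the 8x spatial downsampling, gives 16x spatial and 4x temporal compression into 48 latent channels. The checkpoint path and resolution are again placeholders; height and width divisible by 16 keep the shapes exact.

import torch

from videox_fun.models import AutoencoderKLWan3_8

# Hypothetical path to the Wan 2.2 VAE weights.
vae = AutoencoderKLWan3_8.from_pretrained("models/Wan2.2_VAE.pth")

# Dummy clip: 1 + 4k frames, values in [-1, 1].
video = torch.zeros(1, 3, 17, 480, 832)

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()   # [1, 48, 5, 30, 52]
    recon = vae.decode(latents).sample                 # [1, 3, 17, 480, 832]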
videox_fun/models/wan_xlm_roberta.py ADDED
@@ -0,0 +1,170 @@
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ['XLMRoberta', 'xlm_roberta_large']
8
+
9
+
10
+ class SelfAttention(nn.Module):
11
+
12
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13
+ assert dim % num_heads == 0
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.num_heads = num_heads
17
+ self.head_dim = dim // num_heads
18
+ self.eps = eps
19
+
20
+ # layers
21
+ self.q = nn.Linear(dim, dim)
22
+ self.k = nn.Linear(dim, dim)
23
+ self.v = nn.Linear(dim, dim)
24
+ self.o = nn.Linear(dim, dim)
25
+ self.dropout = nn.Dropout(dropout)
26
+
27
+ def forward(self, x, mask):
28
+ """
29
+ x: [B, L, C].
30
+ """
31
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32
+
33
+ # compute query, key, value
34
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37
+
38
+ # compute attention
39
+ p = self.dropout.p if self.training else 0.0
40
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
41
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42
+
43
+ # output
44
+ x = self.o(x)
45
+ x = self.dropout(x)
46
+ return x
47
+
48
+
49
+ class AttentionBlock(nn.Module):
50
+
51
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52
+ super().__init__()
53
+ self.dim = dim
54
+ self.num_heads = num_heads
55
+ self.post_norm = post_norm
56
+ self.eps = eps
57
+
58
+ # layers
59
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
60
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
61
+ self.ffn = nn.Sequential(
62
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63
+ nn.Dropout(dropout))
64
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
65
+
66
+ def forward(self, x, mask):
67
+ if self.post_norm:
68
+ x = self.norm1(x + self.attn(x, mask))
69
+ x = self.norm2(x + self.ffn(x))
70
+ else:
71
+ x = x + self.attn(self.norm1(x), mask)
72
+ x = x + self.ffn(self.norm2(x))
73
+ return x
74
+
75
+
76
+ class XLMRoberta(nn.Module):
77
+ """
78
+ XLMRobertaModel with no pooler and no LM head.
79
+ """
80
+
81
+ def __init__(self,
82
+ vocab_size=250002,
83
+ max_seq_len=514,
84
+ type_size=1,
85
+ pad_id=1,
86
+ dim=1024,
87
+ num_heads=16,
88
+ num_layers=24,
89
+ post_norm=True,
90
+ dropout=0.1,
91
+ eps=1e-5):
92
+ super().__init__()
93
+ self.vocab_size = vocab_size
94
+ self.max_seq_len = max_seq_len
95
+ self.type_size = type_size
96
+ self.pad_id = pad_id
97
+ self.dim = dim
98
+ self.num_heads = num_heads
99
+ self.num_layers = num_layers
100
+ self.post_norm = post_norm
101
+ self.eps = eps
102
+
103
+ # embeddings
104
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105
+ self.type_embedding = nn.Embedding(type_size, dim)
106
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107
+ self.dropout = nn.Dropout(dropout)
108
+
109
+ # blocks
110
+ self.blocks = nn.ModuleList([
111
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112
+ for _ in range(num_layers)
113
+ ])
114
+
115
+ # norm layer
116
+ self.norm = nn.LayerNorm(dim, eps=eps)
117
+
118
+ def forward(self, ids):
119
+ """
120
+ ids: [B, L] of torch.LongTensor.
121
+ """
122
+ b, s = ids.shape
123
+ mask = ids.ne(self.pad_id).long()
124
+
125
+ # embeddings
126
+ x = self.token_embedding(ids) + \
127
+ self.type_embedding(torch.zeros_like(ids)) + \
128
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129
+ if self.post_norm:
130
+ x = self.norm(x)
131
+ x = self.dropout(x)
132
+
133
+ # blocks
134
+ mask = torch.where(
135
+ mask.view(b, 1, 1, s).gt(0), 0.0,
136
+ torch.finfo(x.dtype).min)
137
+ for block in self.blocks:
138
+ x = block(x, mask)
139
+
140
+ # output
141
+ if not self.post_norm:
142
+ x = self.norm(x)
143
+ return x
144
+
145
+
146
+ def xlm_roberta_large(pretrained=False,
147
+ return_tokenizer=False,
148
+ device='cpu',
149
+ **kwargs):
150
+ """
151
+ XLMRobertaLarge adapted from Huggingface.
152
+ """
153
+ # params
154
+ cfg = dict(
155
+ vocab_size=250002,
156
+ max_seq_len=514,
157
+ type_size=1,
158
+ pad_id=1,
159
+ dim=1024,
160
+ num_heads=16,
161
+ num_layers=24,
162
+ post_norm=True,
163
+ dropout=0.1,
164
+ eps=1e-5)
165
+ cfg.update(**kwargs)
166
+
167
+ # init a model on device
168
+ with torch.device(device):
169
+ model = XLMRoberta(**cfg)
170
+ return model
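
Finally, a minimal sketch for the text encoder above, assuming token ids produced by an XLM-RoBERTa tokenizer; the model is randomly initialised here, whereas in practice its weights would be loaded from the corresponding Wan checkpoint.

import torch

from videox_fun.models.wan_xlm_roberta import xlm_roberta_large

model = xlm_roberta_large(device="cpu").eval()

# ids: [B, L] LongTensor of token ids, padded with pad_id=1.
ids = torch.randint(5, 250002, (1, 16))    # dummy ids, none equal to pad_id
with torch.no_grad():
    features = model(ids)                  # [1, 16, 1024] per-token features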