diffusers integration

#23
README.md CHANGED
@@ -185,7 +185,6 @@ guider = AdaptiveProjectedGuidance(
185
 
186
  pipe = MotifVideoPipeline.from_pretrained(
187
  "Motif-Technologies/Motif-Video-2B",
188
- revision="diffusers-integration",
189
  torch_dtype=torch.bfloat16,
190
  guider=guider,
191
  )
@@ -234,9 +233,8 @@ guider = AdaptiveProjectedGuidance(
234
  normalization_dims="spatial",
235
  )
236
 
237
- pipe = MotifVideoPipeline.from_pretrained(
238
  "Motif-Technologies/Motif-Video-2B",
239
- revision="diffusers-integration",
240
  torch_dtype=torch.bfloat16,
241
  guider=guider,
242
  )
 
185
 
186
  pipe = MotifVideoPipeline.from_pretrained(
187
  "Motif-Technologies/Motif-Video-2B",
 
188
  torch_dtype=torch.bfloat16,
189
  guider=guider,
190
  )
 
233
  normalization_dims="spatial",
234
  )
235
 
236
+ pipe = MotifVideoImage2VideoPipeline.from_pretrained(
237
  "Motif-Technologies/Motif-Video-2B",
 
238
  torch_dtype=torch.bfloat16,
239
  guider=guider,
240
  )
inference.py DELETED
@@ -1,210 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Motif-Video 2B — Text-to-Video inference.
3
-
4
- GPU requirements: ~24GB VRAM for 720p (1280x736, 121 frames).
5
- Requires: torch, diffusers (with MotifVideoPipeline), transformers>=5.5.4,
6
- accelerate, ftfy, einops, sentencepiece, regex
7
-
8
- Uses Adaptive Projected Guidance (APG) and DPMSolver++ scheduler by default.
9
- """
10
-
11
- import argparse
12
-
13
- import torch
14
- from diffusers import (
15
- AdaptiveProjectedGuidance,
16
- DPMSolverMultistepScheduler,
17
- MotifVideoPipeline,
18
- )
19
- from diffusers.utils import export_to_video
20
-
21
- _DEFAULT_NEGATIVE_PROMPT = (
22
- "text overlay, graphic overlay, watermark, logo, subtitles, timestamp, "
23
- "broadcast graphics, UI elements, random letters, frozen pose, rigid, "
24
- "static expression, jerky motion, mechanical motion, discontinuous motion, "
25
- "flat framing, depthless, dull lighting, monotone, crushed shadows, "
26
- "blown-out highlights, shifting background, fading background, poor continuity, "
27
- "identity drift, deformation, flickering, ghosting, smearing, duplication, "
28
- "mutated proportions, inconsistent clothing, flat colors, desaturated, "
29
- "tonally compressed, poor background separation, exposure shift, "
30
- "uneven brightness, color balance shift"
31
- )
32
-
33
-
34
- def parse_args():
35
- parser = argparse.ArgumentParser(description="Motif-Video 2B Inference (T2V)")
36
- parser.add_argument(
37
- "--model-path",
38
- type=str,
39
- default="Motif-Technologies/Motif-Video-2B",
40
- help="HuggingFace model ID or local checkpoint path",
41
- )
42
- parser.add_argument(
43
- "--prompt",
44
- type=str,
45
- default="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
46
- help="Text prompt for video generation",
47
- )
48
- parser.add_argument(
49
- "--negative-prompt",
50
- type=str,
51
- default=_DEFAULT_NEGATIVE_PROMPT,
52
- help="Negative prompt",
53
- )
54
- parser.add_argument("--output", type=str, default="output.mp4", help="Output video file path")
55
- parser.add_argument("--num-frames", type=int, default=121, help="Number of frames to generate (121 = ~5s at 24fps)")
56
- parser.add_argument("--height", type=int, default=736, help="Video height in pixels")
57
- parser.add_argument("--width", type=int, default=1280, help="Video width in pixels")
58
- parser.add_argument("--guidance-scale", type=float, default=8.0, help="Classifier-free guidance scale")
59
- parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of denoising steps")
60
- parser.add_argument("--fps", type=int, default=24, help="Output video frame rate")
61
- parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
62
- parser.add_argument(
63
- "--dtype",
64
- type=str,
65
- default="bfloat16",
66
- choices=["float16", "bfloat16", "float32"],
67
- help="Model dtype",
68
- )
69
- parser.add_argument(
70
- "--use-sage-attention",
71
- action="store_true",
72
- help="Enable SageAttention for ~2x faster attention (requires: pip install sageattention>=2.1.1 from GitHub source)",
73
- )
74
- return parser.parse_args()
75
-
76
-
77
- def _enable_sage_attention(transformer):
78
- """Patch transformer attention to use SageAttention.
79
-
80
- Only patches _compute_attention (self-attention path). Cross-attention
81
- uses _handle_cross_attention_mode, which calls F.scaled_dot_product_attention directly and is
82
- unaffected by this patch.
83
-
84
- Mask handling follows motif-models dispatch_optimized_attention pattern:
85
- - mask=None: sage directly
86
- - mask with uniform active length: slice active region -> sage -> pad back
87
- - mask with non-uniform active length: SDPA fallback
88
- """
89
- from sageattention import sageattn
90
- from diffusers.models.transformers.transformer_motif_video import MotifVideoAttnProcessor2_0
91
-
92
- _orig_compute = MotifVideoAttnProcessor2_0._compute_attention
93
-
94
- def _sage_compute(self, query, key, value, attention_mask):
95
- if attention_mask is None:
96
- out = sageattn(
97
- query.contiguous(), key.contiguous(), value.contiguous(),
98
- tensor_layout="HND", is_causal=False,
99
- )
100
- out = out.transpose(1, 2).flatten(2, 3).to(query.dtype)
101
- return out
102
-
103
- # Find active token count from mask (shape: [B, 1, 1, S])
104
- padding_indices = attention_mask.sum(dim=-1).long().flatten()
105
- common_padding_index = padding_indices[0]
106
- is_uniform = (padding_indices == common_padding_index).all()
107
-
108
- if not is_uniform:
109
- return _orig_compute(self, query, key, value, attention_mask)
110
-
111
- active_len = common_padding_index.item()
112
- S = query.shape[2]
113
-
114
- if active_len == S:
115
- out = sageattn(
116
- query.contiguous(), key.contiguous(), value.contiguous(),
117
- tensor_layout="HND", is_causal=False,
118
- )
119
- out = out.transpose(1, 2).flatten(2, 3).to(query.dtype)
120
- return out
121
-
122
- # Slice to active region, run sage, pad back
123
- q_a = query[:, :, :active_len, :].contiguous()
124
- k_a = key[:, :, :active_len, :].contiguous()
125
- v_a = value[:, :, :active_len, :].contiguous()
126
-
127
- out_a = sageattn(q_a, k_a, v_a, tensor_layout="HND", is_causal=False)
128
-
129
- out = query.new_zeros(query.shape)
130
- out[:, :, :active_len, :] = out_a
131
- out = out.transpose(1, 2).flatten(2, 3).to(query.dtype)
132
- return out
133
-
134
- MotifVideoAttnProcessor2_0._compute_attention = _sage_compute
135
- transformer.to(memory_format=torch.channels_last_3d)
136
- print("[SageAttention] Enabled (patched _compute_attention + channels_last_3d)")
137
-
138
-
139
- def main():
140
- args = parse_args()
141
-
142
- dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
143
- torch_dtype = dtype_map[args.dtype]
144
-
145
- print(f"[T2V] Loading model from: {args.model_path}")
146
-
147
- guider = AdaptiveProjectedGuidance(
148
- guidance_scale=args.guidance_scale,
149
- adaptive_projected_guidance_rescale=12.0,
150
- adaptive_projected_guidance_momentum=0.1,
151
- use_original_formulation=True,
152
- normalization_dims="spatial",
153
- )
154
-
155
- pipe = MotifVideoPipeline.from_pretrained(
156
- args.model_path,
157
- torch_dtype=torch_dtype,
158
- guider=guider,
159
- )
160
-
161
- # Replace scheduler with DPMSolver++ for faster convergence and better quality.
162
- # Subclass ignores pipeline-supplied sigmas (PR branch always passes them)
163
- # and uses its own flow-matching sigma schedule instead.
164
- class _FlowDPMSolver(DPMSolverMultistepScheduler):
165
- def set_timesteps(self, num_inference_steps=None, device=None,
166
- sigmas=None, mu=None, timesteps=None):
167
- if sigmas is not None and num_inference_steps is None:
168
- num_inference_steps = len(sigmas)
169
- super().set_timesteps(
170
- num_inference_steps=num_inference_steps,
171
- device=device, timesteps=timesteps,
172
- )
173
-
174
- pipe.scheduler = _FlowDPMSolver(
175
- num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
176
- algorithm_type="dpmsolver++",
177
- solver_order=2,
178
- prediction_type="flow_prediction",
179
- use_flow_sigmas=True,
180
- flow_shift=15.0,
181
- )
182
-
183
- # Offload model components to CPU between uses to reduce peak VRAM
184
- pipe.enable_model_cpu_offload()
185
-
186
- if args.use_sage_attention:
187
- _enable_sage_attention(pipe.transformer)
188
-
189
- generator = torch.Generator(device="cuda").manual_seed(args.seed)
190
-
191
- print(f"Generating video: {args.width}x{args.height}, {args.num_frames} frames, {args.num_inference_steps} steps")
192
- output = pipe(
193
- prompt=args.prompt,
194
- negative_prompt=args.negative_prompt,
195
- height=args.height,
196
- width=args.width,
197
- num_frames=args.num_frames,
198
- num_inference_steps=args.num_inference_steps,
199
- frame_rate=args.fps,
200
- use_linear_quadratic_schedule=False,
201
- generator=generator,
202
- )
203
-
204
- video_frames = output.frames[0]
205
- export_to_video(video_frames, args.output, fps=args.fps)
206
- print(f"Video saved to: {args.output}")
207
-
208
-
209
- if __name__ == "__main__":
210
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_index.json CHANGED
@@ -14,7 +14,7 @@
14
  "GemmaTokenizer"
15
  ],
16
  "transformer": [
17
- "transformer_motif_video",
18
  "MotifVideoTransformer3DModel"
19
  ],
20
  "vae": [
 
14
  "GemmaTokenizer"
15
  ],
16
  "transformer": [
17
+ "diffusers",
18
  "MotifVideoTransformer3DModel"
19
  ],
20
  "vae": [
transformer/config.json CHANGED
@@ -3,7 +3,6 @@
3
  "_diffusers_version": "0.36.0",
4
  "_library": "diffusers",
5
  "attention_head_dim": 128,
6
- "base_latent_size": null,
7
  "image_embed_dim": 1152,
8
  "in_channels": 33,
9
  "mlp_ratio": 4.0,
@@ -15,7 +14,6 @@
15
  "out_channels": 16,
16
  "patch_size": 2,
17
  "patch_size_t": 1,
18
- "pooled_projection_dim": null,
19
  "qk_norm": "rms_norm",
20
  "rope_axes_dim": [
21
  16,
 
3
  "_diffusers_version": "0.36.0",
4
  "_library": "diffusers",
5
  "attention_head_dim": 128,
 
6
  "image_embed_dim": 1152,
7
  "in_channels": 33,
8
  "mlp_ratio": 4.0,
 
14
  "out_channels": 16,
15
  "patch_size": 2,
16
  "patch_size_t": 1,
 
17
  "qk_norm": "rms_norm",
18
  "rope_axes_dim": [
19
  16,
transformer/transformer_motif_video.py DELETED
@@ -1,1350 +0,0 @@
1
- # Copyright 2026 Motif Technologies. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import math
16
- from functools import lru_cache
17
- from typing import Any, Dict, List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import torch
21
- import torch.nn as nn
22
- import torch.nn.functional as F
23
- from diffusers.configuration_utils import ConfigMixin, register_to_config
24
- from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
25
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
26
- from diffusers.models.attention import FeedForward
27
- from diffusers.models.attention_processor import Attention, AttentionProcessor
28
- from diffusers.models.cache_utils import CacheMixin
29
- from diffusers.models.embeddings import (
30
- PixArtAlphaTextProjection,
31
- TimestepEmbedding,
32
- Timesteps,
33
- )
34
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
35
- from diffusers.models.modeling_utils import ModelMixin
36
- from diffusers.models.normalization import (
37
- AdaLayerNormContinuous,
38
- AdaLayerNormZero,
39
- AdaLayerNormZeroSingle,
40
- )
41
- from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
42
-
43
- # Stub functions for TREAD (Token REduction with Approximated Distillation).
44
- # These stubs ensure TREAD code paths are never activated during inference
45
- # without requiring the motif_core package.
46
- def is_tread_start(block_idx, start, end): return False
47
- def is_tread_end(block_idx, start, end): return False
48
-
49
-
50
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
-
52
- NUM_TRAIN_TIMESTEPS = 1000
53
-
54
-
55
- def apply_rotary_emb(
56
- x: torch.Tensor,
57
- freqs_cis: Tuple[torch.Tensor, torch.Tensor],
58
- use_real: bool = True,
59
- use_real_unbind_dim: int = -1,
60
- ) -> torch.Tensor:
61
- """
62
- Apply rotary positional embeddings (RoPE) to input tensors.
63
-
64
- This implementation supports both standard 2D RoPE tensors [L, Dh] and batched 4D RoPE
65
- tensors [B, 1, L, Dh] for compatibility with TREAD's token-dropping mechanism where
66
- different batches may have different token subsets.
67
-
68
- Args:
69
- x: Input tensor of shape [B, H, L, Dh].
70
- freqs_cis: Tuple of (cos, sin) tensors. Supports shapes [L, Dh] or [B, 1, L, Dh].
71
- use_real: Whether to use real-valued RoPE implementation.
72
- use_real_unbind_dim: Dimension to unbind when using real-valued RoPE (-1 or -2).
73
-
74
- Returns:
75
- Tensor with rotary embeddings applied, same shape as input x.
76
- """
77
- if use_real:
78
- cos, sin = freqs_cis
79
- if cos.dim() == 2: # [L, Dh] → [1, 1, L, Dh]
80
- cos = cos.unsqueeze(0).unsqueeze(0)
81
- sin = sin.unsqueeze(0).unsqueeze(0)
82
- if cos.dim() != 4 or sin.dim() != 4:
83
- raise RuntimeError(f"RoPE must be 2D or 4D, got cos={cos.dim()}D, sin={sin.dim()}D")
84
-
85
- cos, sin = cos.to(x.device), sin.to(x.device)
86
-
87
- if cos.size(-2) != x.size(-2) or cos.size(-1) != x.size(-1):
88
- raise RuntimeError(
89
- f"RoPE shape mismatch: rope[-2:]=({cos.size(-2)},{cos.size(-1)}) vs x[-2:]=({x.size(-2)},{x.size(-1)})"
90
- )
91
-
92
- if use_real_unbind_dim == -1:
93
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
94
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
95
- elif use_real_unbind_dim == -2:
96
- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
97
- x_rotated = torch.cat([-x_imag, x_real], dim=-1)
98
- else:
99
- raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
100
-
101
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
102
- return out
103
- else:
104
- x_rot = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
105
- freqs = freqs_cis.unsqueeze(2)
106
- x_out = torch.view_as_real(x_rot * freqs).flatten(3)
107
- return x_out.type_as(x)
108
-
109
-
110
- class MotifVideoAttnProcessor2_0:
111
- def __init__(self):
112
- if not hasattr(F, "scaled_dot_product_attention"):
113
- raise ImportError(
114
- "MotifVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
115
- )
116
-
117
- def __call__(
118
- self,
119
- attn: Attention,
120
- hidden_states: torch.Tensor,
121
- encoder_hidden_states: Optional[torch.Tensor] = None,
122
- attention_mask: Optional[torch.Tensor] = None,
123
- image_rotary_emb: Optional[torch.Tensor] = None,
124
- query_input: Optional[torch.Tensor] = None,
125
- key_input: Optional[torch.Tensor] = None,
126
- value_input: Optional[torch.Tensor] = None,
127
- ) -> torch.Tensor:
128
- # Cross-attention mode: query already projected externally (cross_attn_query_proj + norm),
129
- # skip to_q and only apply reshape + norm_q + RoPE. K/V use to_k/to_v as normal.
130
- if query_input is not None:
131
- query = query_input.unflatten(2, (attn.heads, -1)).transpose(1, 2)
132
- key = attn.to_k(key_input)
133
- value = attn.to_v(value_input)
134
-
135
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
136
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
137
-
138
- if attn.norm_q is not None:
139
- query = attn.norm_q(query)
140
- if attn.norm_k is not None:
141
- key = attn.norm_k(key)
142
-
143
- if image_rotary_emb is not None:
144
- query = apply_rotary_emb(query, image_rotary_emb)
145
-
146
- hidden_states = F.scaled_dot_product_attention(
147
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
148
- )
149
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
150
- hidden_states = hidden_states.to(query.dtype)
151
- return hidden_states, None
152
-
153
- if attn.add_q_proj is None and encoder_hidden_states is not None:
154
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
155
-
156
- # 1. QKV projections
157
- query = attn.to_q(hidden_states)
158
- key = attn.to_k(hidden_states)
159
- value = attn.to_v(hidden_states)
160
-
161
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
162
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
163
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
164
-
165
- # 2. QK normalization
166
- if attn.norm_q is not None:
167
- query = attn.norm_q(query)
168
- if attn.norm_k is not None:
169
- key = attn.norm_k(key)
170
-
171
- # 3. Rotary positional embeddings (RoPE) applied to the latent stream
172
- if image_rotary_emb is not None:
173
- if attn.add_q_proj is None and encoder_hidden_states is not None:
174
- query = torch.cat(
175
- [
176
- apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
177
- query[:, :, -encoder_hidden_states.shape[1] :],
178
- ],
179
- dim=2,
180
- )
181
- key = torch.cat(
182
- [
183
- apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
184
- key[:, :, -encoder_hidden_states.shape[1] :],
185
- ],
186
- dim=2,
187
- )
188
- else:
189
- query = apply_rotary_emb(query, image_rotary_emb)
190
- key = apply_rotary_emb(key, image_rotary_emb)
191
-
192
- # 4. Encoder condition QKV projection and normalization
193
- if attn.add_q_proj is not None and encoder_hidden_states is not None:
194
- encoder_query = attn.add_q_proj(encoder_hidden_states)
195
- encoder_key = attn.add_k_proj(encoder_hidden_states)
196
- encoder_value = attn.add_v_proj(encoder_hidden_states)
197
-
198
- encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
199
- encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
200
- encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
201
-
202
- if attn.norm_added_q is not None:
203
- encoder_query = attn.norm_added_q(encoder_query)
204
- if attn.norm_added_k is not None:
205
- encoder_key = attn.norm_added_k(encoder_key)
206
-
207
- query = torch.cat([query, encoder_query], dim=2)
208
- key = torch.cat([key, encoder_key], dim=2)
209
- value = torch.cat([value, encoder_value], dim=2)
210
-
211
- # 5. Attention
212
- hidden_states = F.scaled_dot_product_attention(
213
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
214
- )
215
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
216
- hidden_states = hidden_states.to(query.dtype)
217
-
218
- # 6. Output projection
219
- if encoder_hidden_states is not None:
220
- hidden_states, encoder_hidden_states = (
221
- hidden_states[:, : -encoder_hidden_states.shape[1]],
222
- hidden_states[:, -encoder_hidden_states.shape[1] :],
223
- )
224
-
225
- if getattr(attn, "to_out", None) is not None:
226
- hidden_states = attn.to_out[0](hidden_states)
227
- hidden_states = attn.to_out[1](hidden_states)
228
-
229
- if getattr(attn, "to_add_out", None) is not None:
230
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
231
-
232
- return hidden_states, encoder_hidden_states
233
-
234
-
235
- class MotifVideoPatchEmbed(nn.Module):
236
- def __init__(
237
- self,
238
- patch_size: Union[int, Tuple[int, int, int]] = 16,
239
- in_chans: int = 3,
240
- embed_dim: int = 768,
241
- ) -> None:
242
- super().__init__()
243
-
244
- patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
245
- self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
246
-
247
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
248
- hidden_states = self.proj(hidden_states)
249
- hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
250
- return hidden_states
251
-
252
-
253
- class MotifVideoAdaNorm(nn.Module):
254
- def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
255
- super().__init__()
256
-
257
- out_features = out_features or 2 * in_features
258
- self.linear = nn.Linear(in_features, out_features)
259
- self.nonlinearity = nn.SiLU()
260
-
261
- def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
262
- temb = self.linear(self.nonlinearity(temb))
263
- gate_msa, gate_mlp = temb.chunk(2, dim=1)
264
- gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
265
- return gate_msa, gate_mlp
266
-
267
-
268
- class MotifVideoConditionEmbedding(nn.Module):
269
- def __init__(
270
- self,
271
- embedding_dim: int,
272
- pooled_projection_dim: int | None,
273
- ):
274
- super().__init__()
275
-
276
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
277
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
278
-
279
- if isinstance(pooled_projection_dim, int):
280
- self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
281
-
282
- def forward(
283
- self,
284
- timestep: torch.Tensor,
285
- pooled_projection: torch.Tensor | None = None,
286
- ) -> Tuple[torch.Tensor, torch.Tensor]:
287
- timesteps_proj = self.time_proj(timestep)
288
- timestep_embedder_dtype = next(self.timestep_embedder.parameters()).dtype
289
- conditioning = self.timestep_embedder(timesteps_proj.to(timestep_embedder_dtype)) # (N, D)
290
- if pooled_projection is not None:
291
- conditioning = conditioning + self.text_embedder(pooled_projection)
292
-
293
- token_replace_emb = None
294
-
295
- return conditioning, token_replace_emb
296
-
297
-
298
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L485-L486
299
- def find_correction_factor(num_rotations, dim, base, max_position_embeddings):
300
- dtype = num_rotations.dtype if isinstance(num_rotations, torch.Tensor) else torch.float32
301
- max_pos_tensor = torch.as_tensor(max_position_embeddings, dtype=dtype)
302
- return (dim * torch.log(max_pos_tensor / (num_rotations * 2 * math.pi))) / (
303
- 2 * math.log(base)
304
- ) # Inverse dim formula to find number of rotations
305
-
306
-
307
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L489-L495
308
- def find_correction_range(low_ratio, high_ratio, dim, base, ori_max_pe_len):
309
- """
310
- Find the correction range for NTK-by-parts interpolation.
311
- """
312
- low = torch.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len))
313
- high = torch.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len))
314
- low = torch.clamp(low, min=0)
315
- high = torch.clamp(high, max=dim - 1)
316
- return low, high # Clamp values just in case
317
-
318
-
319
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L498-L504
320
- def linear_ramp_mask(min_val, max_val, num_dim):
321
- if isinstance(min_val, torch.Tensor):
322
- if (min_val == max_val).all():
323
- max_val = max_val + 0.001
324
- elif min_val == max_val:
325
- max_val += 0.001
326
-
327
- linear_func = (torch.arange(num_dim, dtype=torch.float32) - min_val) / (max_val - min_val)
328
- ramp_func = torch.clamp(linear_func, 0, 1)
329
- return ramp_func
330
-
331
-
332
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L507-L511
333
- def find_newbase_ntk(dim, base, scale):
334
- """
335
- Calculate the new base for NTK-aware scaling.
336
- """
337
- # Avoid division by zero when dim == 2 (or invalid smaller values).
338
- # In these degenerate cases, fall back to the original base (no NTK adjustment).
339
- if dim <= 2:
340
- return base
341
- return base * (scale ** (dim / (dim - 2)))
342
-
343
-
344
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L514-L652
345
- def get_1d_rotary_pos_embed(
346
- dim: int,
347
- pos: Union[np.ndarray, int],
348
- theta: float = 10000.0,
349
- use_real=False,
350
- linear_factor=1.0,
351
- ntk_factor=1.0,
352
- repeat_interleave_real=True,
353
- freqs_dtype=torch.float32,
354
- yarn=False,
355
- max_pe_len=None,
356
- ori_max_pe_len=64,
357
- dype=False,
358
- current_timestep=1.0,
359
- ):
360
- """
361
- Precompute the frequency tensor for complex exponentials with RoPE.
362
- Supports YARN interpolation for vision transformers.
363
-
364
- Args:
365
- dim (`int`):
366
- Dimension of the frequency tensor.
367
- pos (`np.ndarray` or `int`):
368
- Position indices for the frequency tensor. [S] or scalar.
369
- theta (`float`, *optional*, defaults to 10000.0):
370
- Scaling factor for frequency computation.
371
- use_real (`bool`, *optional*, defaults to False):
372
- If True, return real part and imaginary part separately. Otherwise, return complex numbers.
373
- linear_factor (`float`, *optional*, defaults to 1.0):
374
- Scaling factor for linear interpolation.
375
- ntk_factor (`float`, *optional*, defaults to 1.0):
376
- Scaling factor for NTK-Aware RoPE.
377
- repeat_interleave_real (`bool`, *optional*, defaults to True):
378
- If True and use_real, real and imaginary parts are interleaved with themselves to reach dim.
379
- Otherwise, they are concatenated.
380
- freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
381
- Data type of the frequency tensor.
382
- yarn (`bool`, *optional*, defaults to False):
383
- If True, use YARN interpolation combining NTK, linear, and base methods.
384
- max_pe_len (`int`, *optional*):
385
- Maximum position encoding length (current patches for vision models).
386
- ori_max_pe_len (`int`, *optional*, defaults to 64):
387
- Original maximum position encoding length (base patches for vision models).
388
- dype (`bool`, *optional*, defaults to False):
389
- If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
390
- current_timestep (`float`, *optional*, defaults to 1.0):
391
- Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.
392
-
393
- Returns:
394
- `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
395
- If use_real=True, returns tuple of (cos, sin) tensors.
396
- """
397
- assert dim % 2 == 0
398
-
399
- if isinstance(pos, int):
400
- pos = torch.arange(pos)
401
- if isinstance(pos, np.ndarray):
402
- pos = torch.from_numpy(pos)
403
-
404
- device = pos.device
405
-
406
- if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
407
- if not isinstance(max_pe_len, torch.Tensor):
408
- max_pe_len = torch.tensor(max_pe_len, dtype=freqs_dtype, device=device)
409
-
410
- scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
411
-
412
- beta_0 = 1.25
413
- beta_1 = 0.75
414
- gamma_0 = 16
415
- gamma_1 = 2
416
-
417
- freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
418
-
419
- freqs_linear = 1.0 / torch.einsum(
420
- "..., f -> ... f",
421
- scale,
422
- (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)),
423
- )
424
-
425
- new_base = find_newbase_ntk(dim, theta, scale)
426
- if new_base.dim() > 0:
427
- new_base = new_base.view(-1, 1)
428
- freqs_ntk = 1.0 / torch.pow(new_base, (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
429
- if freqs_ntk.dim() > 1:
430
- freqs_ntk = freqs_ntk.squeeze()
431
-
432
- if dype:
433
- beta_0 = torch.pow(beta_0, 2.0 * torch.pow(current_timestep, 2.0))
434
- beta_1 = torch.pow(beta_1, 2.0 * torch.pow(current_timestep, 2.0))
435
-
436
- low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
437
- high = torch.clamp(high, max=dim // 2)
438
-
439
- freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
440
- freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
441
-
442
- if dype:
443
- gamma_0 = torch.pow(gamma_0, 2.0 * torch.pow(current_timestep, 2.0))
444
- gamma_1 = torch.pow(gamma_1, 2.0 * torch.pow(current_timestep, 2.0))
445
-
446
- low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
447
- high = torch.clamp(high, max=dim // 2)
448
-
449
- freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
450
- freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
451
-
452
- else:
453
- theta_ntk = theta * ntk_factor
454
- freqs = 1.0 / (theta_ntk ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)) / linear_factor
455
-
456
- freqs = torch.outer(pos, freqs)
457
-
458
- is_npu = freqs.device.type == "npu"
459
- if is_npu:
460
- freqs = freqs.float()
461
-
462
- if use_real and repeat_interleave_real:
463
- freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
464
- freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
465
-
466
- if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
467
- mscale = torch.where(scale <= 1.0, 1.0, 0.1 * torch.log(scale) + 1.0).to(scale)
468
- freqs_cos = freqs_cos * mscale
469
- freqs_sin = freqs_sin * mscale
470
-
471
- return freqs_cos, freqs_sin
472
- elif use_real:
473
- freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
474
- freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
475
- return freqs_cos, freqs_sin
476
- else:
477
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
478
- return freqs_cis
479
-
480
-
481
- class MotifVideoRotaryPosEmbed(nn.Module):
482
- def __init__(
483
- self,
484
- patch_size: int,
485
- patch_size_t: int,
486
- rope_dim: List[int],
487
- theta: float = 256.0,
488
- base_latent_size: int | None = None,
489
- ):
490
- """
491
- Rotary Positional Embedding (RoPE) for video latents.
492
-
493
- Args:
494
- patch_size (`int`):
495
- Spatial patch size (e.g., 2).
496
- patch_size_t (`int`):
497
- Temporal patch size (e.g., 1).
498
- rope_dim (`List[int]`):
499
- Dimensions for RoPE across [Time, Height, Width] axes.
500
- theta (`float`, *optional*, defaults to 256.0):
501
- Base frequency for rotary embeddings.
502
- base_latent_size (`int`, *optional*):
503
- The maximum spatial dimension (in latent units) seen during training,
504
- i.e. `training_resolution / vae_scale_factor_spatial`.
505
- For example, for 1280x1280 training images and a VAE spatial downscale
506
- (`vae_scale_factor_spatial`) of 8, this would be 160; for a downscale
507
- of 16, it would be 80.
508
- """
509
- super().__init__()
510
-
511
- self.patch_size = patch_size
512
- self.patch_size_t = patch_size_t
513
- self.rope_dim = rope_dim
514
- self.theta = theta
515
- self.base_latent_size = base_latent_size
516
-
517
- @lru_cache(maxsize=8)
518
- def _get_base_patch_grid_size(self, base_latent_size: Optional[int], patch_size: int) -> Optional[int]:
519
- return base_latent_size // patch_size if base_latent_size else None
520
-
521
- @lru_cache(maxsize=8)
522
- def _get_dynamic_interpolation_scale(self, h: int, w: int, base_grid_size: int) -> float:
523
- return math.sqrt(h * w / (base_grid_size**2))
524
-
525
- def forward(self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
526
- if self.training:
527
- assert self.base_latent_size is None, (
528
- "RoPE interpolation/extrapolation logic should only be enabled for inference. "
529
- f"During training, base_latent_size must be None, but got {self.base_latent_size!r}."
530
- )
531
-
532
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
533
- rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
534
-
535
- axes_grids = []
536
- for i in range(3):
537
- # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
538
- # original implementation creates it on CPU and then moves it to device. This results in numerical
539
- # differences in layerwise debugging outputs, but visually it is the same.
540
- grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
541
- axes_grids.append(grid)
542
- grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
543
- grid = torch.stack(grid, dim=0) # [3, W, H, T]
544
-
545
- base_patch_grid_size = self._get_base_patch_grid_size(self.base_latent_size, self.patch_size)
546
- if base_patch_grid_size is not None:
547
- if base_patch_grid_size <= 0:
548
- raise ValueError(f"base_patch_grid_size must be a positive number, got {base_patch_grid_size}.")
549
- dynamic_interpolation_scale = self._get_dynamic_interpolation_scale(
550
- rope_sizes[1], rope_sizes[2], base_patch_grid_size
551
- )
552
-
553
- normalized_timestep = torch.tensor(1.0)
554
- if not self.training and timestep is not None:
555
- normalized_timestep = timestep[0] / NUM_TRAIN_TIMESTEPS
556
-
557
- freqs = []
558
- for i in range(3):
559
- common_kwargs = {
560
- "dim": self.rope_dim[i],
561
- "pos": grid[i].reshape(-1),
562
- "theta": self.theta,
563
- "use_real": True,
564
- "freqs_dtype": torch.float64,
565
- }
566
-
567
- # Apply scaling only to spatial dimensions (Height and Width, i=1 and i=2)
568
- if i > 0 and base_patch_grid_size is not None and dynamic_interpolation_scale > 1.0:
569
- # We project the training base to the current size using the uniform scale factor.
570
- # max_pe_len tells the RoPE logic the "new" maximum length it's dealing with.
571
- max_pe_len = torch.tensor(
572
- base_patch_grid_size * dynamic_interpolation_scale,
573
- dtype=torch.float64,
574
- device=hidden_states.device,
575
- )
576
-
577
- freq = get_1d_rotary_pos_embed(
578
- **common_kwargs,
579
- yarn=True, # Enable Yet Another RoPE extensioN (YARN) for extrapolation
580
- max_pe_len=max_pe_len,
581
- ori_max_pe_len=base_patch_grid_size, # The original training scale
582
- dype=True, # Enable Dynamic Position Encoding (time-aware)
583
- current_timestep=normalized_timestep,
584
- )
585
- else:
586
- # Time dimension OR within training bounds -> Standard RoPE
587
- freq = get_1d_rotary_pos_embed(**common_kwargs)
588
-
589
- freqs.append(freq)
590
-
591
- freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
592
- freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
593
- return freqs_cos, freqs_sin
594
-
595
-
596
class MotifVideoImageProjection(nn.Module):
    """Projects image embeddings of width `in_features` to width `hidden_size`.

    The pipeline is LayerNorm -> Linear -> GELU -> Linear -> LayerNorm.
    """

    def __init__(self, in_features: int, hidden_size: int):
        super().__init__()
        # Attribute names and registration order define the state_dict layout; keep both stable.
        self.norm_in = nn.LayerNorm(in_features)
        self.linear_1 = nn.Linear(in_features, in_features)
        self.act_fn = nn.GELU()
        self.linear_2 = nn.Linear(in_features, hidden_size)
        self.norm_out = nn.LayerNorm(hidden_size)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Run `image_embeds` through the five stages in order and return the result."""
        projected = image_embeds
        for stage in (self.norm_in, self.linear_1, self.act_fn, self.linear_2, self.norm_out):
            projected = stage(projected)
        return projected
612
-
613
-
614
class MotifVideoSingleTransformerBlock(nn.Module):
    """Single-stream block: latent and conditioning tokens are concatenated and share one
    AdaLN-modulated norm, one attention module and one parallel MLP branch."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 4.0,
        qk_norm: str = "rms_norm",
        norm_type: str = "layer_norm",
        enable_text_cross_attention: bool = False,
    ) -> None:
        """
        Args:
            num_attention_heads: Number of attention heads.
            attention_head_dim: Channels per head; hidden size is `heads * head_dim`.
            mlp_ratio: MLP width as a multiple of the hidden size.
            qk_norm: Query/key normalization passed to `Attention`.
            norm_type: Normalization type passed to `AdaLayerNormZeroSingle`.
            enable_text_cross_attention: When True, adds the extra query/out projections
                used by the text cross-attention step in `forward`.
        """
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim
        mlp_dim = int(hidden_size * mlp_ratio)

        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            bias=True,
            processor=MotifVideoAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=1e-6,
            pre_only=True,
        )

        self.enable_text_cross_attention = enable_text_cross_attention
        if enable_text_cross_attention:
            self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
            # NOTE(review): registered (and therefore part of the state_dict) but never applied
            # in `forward` below - confirm whether that is intentional.
            self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
            self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
            # Zero-init the output projection so the cross-attention branch starts as a no-op.
            nn.init.zeros_(self.cross_attn_out_proj.weight)
            nn.init.zeros_(self.cross_attn_out_proj.bias)

        self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type=norm_type)
        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        token_replace_emb: torch.Tensor | None = None,
        first_frame_num_tokens: int | None = None,
        image_embed_seq_len: int = 0,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: Latent tokens; concatenated with `encoder_hidden_states` along dim 1.
            encoder_hidden_states: Conditioning tokens (image tokens first, then text, when
                `image_embed_seq_len > 0`).
            temb: Conditioning embedding fed to the AdaLN norm.
            attention_mask: Optional mask forwarded to the joint attention call.
            image_rotary_emb: Optional `(cos, sin)` rotary tables forwarded to attention.
            token_replace_emb: Accepted for signature parity; not used by this block.
            first_frame_num_tokens: Accepted for signature parity; not used by this block.
            image_embed_seq_len: Number of leading conditioning tokens to skip when slicing
                out the text tokens for cross-attention.
            encoder_attention_mask: Optional mask over the conditioning tokens; only its text
                part is used, and only by the cross-attention step.

        Returns:
            `(hidden_states, encoder_hidden_states)` split back at the original latent length.
        """
        # Split on the latent length rather than on `-text_len`: a negative-zero slice
        # (`[:-0]` / `[-0:]`) would swap the two halves when there are no conditioning tokens.
        latent_seq_length = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        residual = hidden_states

        # 1. Input normalization
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        norm_hidden_states, norm_encoder_hidden_states = (
            norm_hidden_states[:, :latent_seq_length, :],
            norm_hidden_states[:, latent_seq_length:, :],
        )

        # 2. Attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )

        # Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
        if self.enable_text_cross_attention:
            txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
            text_mask = None
            if encoder_attention_mask is not None:
                text_mask = encoder_attention_mask[:, image_embed_seq_len:]
                text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool)  # [B, 1, 1, L_txt]
            cross_q = self.cross_attn_query_proj(attn_output)
            cross_output, _ = self.attn(
                hidden_states=cross_q,
                query_input=cross_q,
                key_input=txt_kv,
                value_input=txt_kv,
                attention_mask=text_mask,
                image_rotary_emb=image_rotary_emb,
            )
            attn_output = attn_output + self.cross_attn_out_proj(cross_output)

        attn_output = torch.cat([attn_output, context_attn_output], dim=1)

        # 3. Modulation and residual connection
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
        hidden_states = hidden_states + residual

        hidden_states, encoder_hidden_states = (
            hidden_states[:, :latent_seq_length, :],
            hidden_states[:, latent_seq_length:, :],
        )
        return hidden_states, encoder_hidden_states
719
-
720
-
721
class MotifVideoTransformerBlock(nn.Module):
    """Dual-stream block: latent and conditioning tokens keep separate norms and feed-forwards
    but go through one joint attention call."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float,
        qk_norm: str = "rms_norm",
        norm_type: str = "layer_norm",
        enable_text_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        width = num_attention_heads * attention_head_dim
        plain_norm_kwargs = {"elementwise_affine": False, "eps": 1e-6}
        ff_kwargs = {"mult": mlp_ratio, "activation_fn": "gelu-approximate"}

        # Pre-attention AdaLN norms, one per stream.
        self.norm1 = AdaLayerNormZero(width, norm_type=norm_type)
        self.norm1_context = AdaLayerNormZero(width, norm_type=norm_type)

        self.attn = Attention(
            query_dim=width,
            cross_attention_dim=None,
            added_kv_proj_dim=width,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=width,
            context_pre_only=False,
            bias=True,
            processor=MotifVideoAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=1e-6,
        )

        # Optional text cross-attention projections; the output projection starts at zero.
        self.enable_text_cross_attention = enable_text_cross_attention
        if enable_text_cross_attention:
            self.cross_attn_query_proj = nn.Linear(width, width)
            self.cross_attn_query_norm = nn.LayerNorm(width, eps=1e-6)
            self.cross_attn_out_proj = nn.Linear(width, width)
            nn.init.zeros_(self.cross_attn_out_proj.weight)
            nn.init.zeros_(self.cross_attn_out_proj.bias)

        # Pre-MLP norms carry no affine parameters; shift/scale come from `norm1*` instead.
        self.norm2 = nn.LayerNorm(width, **plain_norm_kwargs)
        self.norm2_context = nn.LayerNorm(width, **plain_norm_kwargs)

        self.ff = FeedForward(width, **ff_kwargs)
        self.ff_context = FeedForward(width, **ff_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        token_replace_emb: torch.Tensor | None = None,
        first_frame_num_tokens: int | None = None,
        image_embed_seq_len: int = 0,
        encoder_attention_mask: torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply joint attention, optional text cross-attention and the per-stream MLPs.

        `token_replace_emb` and `first_frame_num_tokens` are accepted but not read here.
        Returns the updated `(hidden_states, encoder_hidden_states)` pair.
        """
        # 1. AdaLN on both streams; each call also yields that stream's gate/shift/scale terms.
        latent_normed, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        context_normed, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # 2. Joint attention over both streams.
        latent_attn, context_attn = self.attn(
            hidden_states=latent_normed,
            encoder_hidden_states=context_normed,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )

        # 3. Gated residual on the latent stream.
        hidden_states = hidden_states + latent_attn * gate_msa.unsqueeze(1)

        # Optional text cross-attention: queries come from the projected attention output,
        # keys/values are the normed conditioning tokens after the leading image tokens.
        if self.enable_text_cross_attention:
            text_tokens = context_normed[:, image_embed_seq_len:, :]
            if encoder_attention_mask is None:
                text_mask = None
            else:
                text_mask = (
                    encoder_attention_mask[:, image_embed_seq_len:].unsqueeze(1).unsqueeze(1).to(torch.bool)
                )  # [B, 1, 1, L_txt]
            cross_query = self.cross_attn_query_proj(latent_attn)
            cross_output, _ = self.attn(
                hidden_states=cross_query,
                query_input=cross_query,
                key_input=text_tokens,
                value_input=text_tokens,
                attention_mask=text_mask,
                image_rotary_emb=image_rotary_emb,
            )
            hidden_states = hidden_states + self.cross_attn_out_proj(cross_output)

        # Gated residual on the conditioning stream.
        encoder_hidden_states = encoder_hidden_states + context_attn * c_gate_msa.unsqueeze(1)

        # Pre-MLP norm followed by the shift/scale modulation from step 1.
        latent_ff_in = self.norm2(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        context_ff_in = (
            self.norm2_context(encoder_hidden_states) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
        )

        # 4. Feed-forward with gated residuals.
        latent_ff_out = self.ff(latent_ff_in)
        context_ff_out = self.ff_context(context_ff_in)

        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * latent_ff_out
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_out

        return hidden_states, encoder_hidden_states
829
-
830
-
831
# Both block types return `(hidden_states, encoder_hidden_states)`, so they are registered
# with the same output indices. A fresh metadata object is created for each registration.
for _motif_block_cls in (MotifVideoTransformerBlock, MotifVideoSingleTransformerBlock):
    TransformerBlockRegistry.register(
        model_class=_motif_block_cls,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
del _motif_block_cls  # do not leave the loop variable in the module namespace
845
-
846
-
847
- class MotifVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
848
- r"""
849
- A Transformer model for video-like data used in [MotifVideo](https://huggingface.co/motif/motifvideo).
850
-
851
- Args:
852
- in_channels (`int`, defaults to `16`):
853
- The number of channels in the input.
854
- out_channels (`int`, defaults to `16`):
855
- The number of channels in the output.
856
- num_attention_heads (`int`, defaults to `24`):
857
- The number of heads to use for multi-head attention.
858
- attention_head_dim (`int`, defaults to `128`):
859
- The number of channels in each head.
860
- num_layers (`int`, defaults to `20`):
861
- The number of layers of dual-stream blocks to use.
862
- num_single_layers (`int`, defaults to `40`):
863
- The number of layers of single-stream blocks to use.
864
-
865
- mlp_ratio (`float`, defaults to `4.0`):
866
- The ratio of the hidden layer size to the input size in the feedforward network.
867
- patch_size (`int`, defaults to `2`):
868
- The size of the spatial patches to use in the patch embedding layer.
869
- patch_size_t (`int`, defaults to `1`):
870
- The size of the temporal patches to use in the patch embedding layer.
871
- qk_norm (`str`, defaults to `rms_norm`):
872
- The normalization to use for the query and key projections in the attention layers.
873
- text_embed_dim (`int`, defaults to `4096`):
874
- Input dimension of text embeddings from the text encoder.
875
- rope_theta (`float`, defaults to `256.0`):
876
- The value of theta to use in the RoPE layer.
877
- rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
878
- The dimensions of the axes to use in the RoPE layer.
879
- base_latent_size (`int`, *optional*):
880
- The maximum spatial dimension (in latent units) seen during training.
881
- For example, if trained on 1280x1280 with a VAE downscale of 16, this is 80.
882
- """
883
-
884
- _supports_gradient_checkpointing = True
885
- _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
886
- _no_split_modules = [
887
- "MotifVideoTransformerBlock",
888
- "MotifVideoSingleTransformerBlock",
889
- "MotifVideoPatchEmbed",
890
- ]
891
-
892
@register_to_config
def __init__(
    self,
    in_channels: int = 33,
    out_channels: int = 16,
    num_attention_heads: int = 24,
    attention_head_dim: int = 128,
    num_layers: int = 20,
    num_single_layers: int = 40,
    num_decoder_layers: int = 0,
    mlp_ratio: float = 4.0,
    patch_size: int = 2,
    patch_size_t: int = 1,
    qk_norm: str = "rms_norm",
    norm_type: str = "layer_norm",
    text_embed_dim: int = 4096,
    image_embed_dim: int | None = None,
    pooled_projection_dim: int | None = None,
    rope_theta: float = 256.0,
    rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
    base_latent_size: int | None = None,
    enable_text_cross_attention_dual: bool = False,
    enable_text_cross_attention_single: bool = False,
) -> None:
    """
    Build the embedders, RoPE module, dual/single-stream block stacks and output projection.

    Notes on parameters that are easy to get wrong:
        in_channels: Defaults to `33` in this signature.
        out_channels: Falls back to `in_channels` when falsy.
        num_single_layers / num_decoder_layers: The last `num_decoder_layers` single-stream
            blocks are treated as decoder blocks and never get text cross-attention.
        image_embed_dim: When not None, an `image_embedder` projection is created.
        enable_text_cross_attention_dual / _single: Enable the cross-attention branch in the
            dual-stream blocks / in the encoder part of the single-stream blocks.

    Raises:
        ValueError: if `num_decoder_layers` is outside `[0, num_single_layers]`, or if a
            block's cross-attention flag does not match the requested configuration.
    """
    super().__init__()

    # The encoder/decoder split of the single-stream stack is derived from these two values
    # by slicing; a negative or oversized decoder count would silently select the wrong blocks.
    if not 0 <= num_decoder_layers <= num_single_layers:
        raise ValueError(
            f"num_decoder_layers must be between 0 and num_single_layers ({num_single_layers}), "
            f"got {num_decoder_layers}."
        )

    inner_dim = num_attention_heads * attention_head_dim
    out_channels = out_channels or in_channels

    # 1. Latent and condition embedders
    self.x_embedder = MotifVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
    self.context_embedder = PixArtAlphaTextProjection(in_features=text_embed_dim, hidden_size=inner_dim)

    # First frame conditioning: Image conditioning embedders
    self.image_embed_dim = image_embed_dim
    if image_embed_dim is not None:
        # Project image embeddings from vision encoder to transformer dim
        self.image_embedder = MotifVideoImageProjection(in_features=image_embed_dim, hidden_size=inner_dim)

    self.time_text_embed = MotifVideoConditionEmbedding(inner_dim, pooled_projection_dim)

    # 2. RoPE
    self.rope = MotifVideoRotaryPosEmbed(
        patch_size, patch_size_t, rope_axes_dim, rope_theta, base_latent_size=base_latent_size
    )

    # Cross-attention config
    self.enable_text_cross_attention_dual = enable_text_cross_attention_dual
    self.enable_text_cross_attention_single = enable_text_cross_attention_single

    # 3. Dual stream transformer blocks
    self.transformer_blocks = nn.ModuleList(
        [
            MotifVideoTransformerBlock(
                num_attention_heads,
                attention_head_dim,
                mlp_ratio=mlp_ratio,
                qk_norm=qk_norm,
                norm_type=norm_type,
                enable_text_cross_attention=enable_text_cross_attention_dual,
            )
            for _ in range(num_layers)
        ]
    )

    # 4. Single stream transformer blocks
    # Encoder blocks get cross-attention; decoder blocks do not (no text stream in decoder)
    num_encoder_single = num_single_layers - num_decoder_layers
    self.single_transformer_blocks = nn.ModuleList(
        [
            MotifVideoSingleTransformerBlock(
                num_attention_heads,
                attention_head_dim,
                mlp_ratio=mlp_ratio,
                qk_norm=qk_norm,
                norm_type=norm_type,
                enable_text_cross_attention=enable_text_cross_attention_single
                if i < num_encoder_single
                else False,
            )
            for i in range(num_single_layers)
        ]
    )

    # 5. Output projection
    self.norm_out = AdaLayerNormContinuous(
        inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type=norm_type
    )
    self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)

    # Verify cross-attention config matches actual block state.
    # Catches silent misconfiguration (e.g. checkpoint config with renamed keys).
    for i, block in enumerate(self.transformer_blocks):
        if block.enable_text_cross_attention != enable_text_cross_attention_dual:
            raise ValueError(
                f"transformer_blocks[{i}].enable_text_cross_attention="
                f"{block.enable_text_cross_attention}, expected {enable_text_cross_attention_dual}. "
                f"Check checkpoint config.json key names match __init__ parameters."
            )
    for i, block in enumerate(self.single_transformer_blocks):
        expected = enable_text_cross_attention_single if i < num_encoder_single else False
        if block.enable_text_cross_attention != expected:
            raise ValueError(
                f"single_transformer_blocks[{i}].enable_text_cross_attention="
                f"{block.enable_text_cross_attention}, expected {expected}. "
                f"Check checkpoint config.json key names match __init__ parameters."
            )

    self.gradient_checkpointing = False
    self.num_decoder_layers = num_decoder_layers
1003
-
1004
@property
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: every sub-module exposing `get_processor()` contributes one
        entry keyed by `"<dotted.module.path>.processor"`, in depth-first (pre-order) order.
    """
    collected = {}

    # Explicit stack instead of recursion; children are pushed reversed so that popping
    # visits them in registration order, parents before their descendants.
    pending = list(self.named_children())[::-1]
    while pending:
        path, module = pending.pop()
        if hasattr(module, "get_processor"):
            collected[f"{path}.processor"] = module.get_processor()
        children = [(f"{path}.{sub_name}", child) for sub_name, child in module.named_children()]
        pending.extend(reversed(children))

    return collected
1028
-
1029
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either one processor instance, applied to **every** module that exposes
            `set_processor`, or a dict mapping `"<dotted.module.path>.processor"` to the
            processor for that module.

            A dict must contain exactly as many entries as `self.attn_processors`; its entries
            are popped as they are assigned, so the caller's dict is consumed.

    Raises:
        ValueError: if a dict is passed whose length differs from the number of attention layers.
    """
    per_layer = isinstance(processor, dict)
    expected = len(self.attn_processors.keys())

    if per_layer and len(processor) != expected:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {expected}. Please make sure to pass {expected} processor classes."
        )

    def assign(path: str, module: torch.nn.Module) -> None:
        # Assign on this module first, then descend into its children.
        if hasattr(module, "set_processor"):
            chosen = processor.pop(f"{path}.processor") if per_layer else processor
            module.set_processor(chosen)

        for sub_name, child in module.named_children():
            assign(f"{path}.{sub_name}", child)

    for top_name, top_module in self.named_children():
        assign(top_name, top_module)
1063
-
1064
- def _maybe_gradient_checkpoint_block(self, block, *args):
1065
- if torch.is_grad_enabled() and self.gradient_checkpointing:
1066
- return self._gradient_checkpointing_func(block, *args)
1067
- return block(*args)
1068
-
1069
- def _get_unwrapped_blocks(self, blocks):
1070
- if hasattr(blocks, "_checkpoint_wrapped_module"):
1071
- return blocks._checkpoint_wrapped_module
1072
- elif hasattr(blocks, "module"):
1073
- return blocks.module
1074
- return blocks
1075
-
1076
- def _create_attention_mask(
1077
- self,
1078
- hidden_states: torch.Tensor,
1079
- encoder_attention_mask: torch.Tensor,
1080
- ) -> torch.Tensor:
1081
- """
1082
- Create attention mask of shape [B, 1, 1, N] where N = L + E,
1083
- based on latent tokens (always valid) and the encoder mask.
1084
-
1085
- Args:
1086
- hidden_states: [B, L, D]
1087
- encoder_attention_mask: [B, E] (required)
1088
-
1089
- Returns:
1090
- attention_mask: [B, 1, 1, N]
1091
- """
1092
- attention_mask = F.pad(
1093
- encoder_attention_mask.to(torch.bool),
1094
- (hidden_states.shape[1], 0),
1095
- value=True,
1096
- )
1097
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, L+E]
1098
- return attention_mask
1099
-
1100
- def forward(
1101
- self,
1102
- hidden_states: torch.Tensor,
1103
- timestep: torch.LongTensor,
1104
- encoder_hidden_states: torch.Tensor,
1105
- encoder_attention_mask: torch.Tensor | None = None,
1106
- pooled_projections: torch.Tensor | None = None,
1107
- image_embeds: torch.Tensor | None = None,
1108
- attention_kwargs: Optional[Dict[str, Any]] = None,
1109
- return_dict: bool = True,
1110
- tread_mixin: Optional[Any] = None,
1111
- tread_disabled: bool = False,
1112
- ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
1113
- """
1114
- Forward pass of the MotifVideoTransformer3DModel.
1115
-
1116
- Args:
1117
- hidden_states: Input latent tensor [B, C, F, H, W].
1118
- timestep: Diffusion timesteps [B].
1119
- encoder_hidden_states: Text conditioning [B, E, D].
1120
- encoder_attention_mask: Mask for text conditioning [B, E].
1121
- pooled_projections: Pooled text embeddings [B, D].
1122
- image_embeds: Optional image embeddings from vision encoder [B, N, D].
1123
- attention_kwargs: Additional arguments for attention processors.
1124
- return_dict: Whether to return a Transformer2DModelOutput.
1125
- tread_mixin: Optional TreadMixin instance for token reduction.
1126
- tread_disabled: When True, force tread_mixin to None (dense pass).
1127
- torch.compile specializes on this bool, producing separate graphs
1128
- for dense vs routed without attribute toggling.
1129
-
1130
- Returns:
1131
- Transformer2DModelOutput or tuple containing the predicted samples.
1132
- """
1133
- if tread_disabled:
1134
- tread_mixin = None
1135
- elif tread_mixin is None:
1136
- tread_mixin = getattr(self, "_inference_tread_mixin", None)
1137
-
1138
- if attention_kwargs is not None:
1139
- attention_kwargs = attention_kwargs.copy()
1140
- lora_scale = attention_kwargs.pop("scale", 1.0)
1141
- else:
1142
- lora_scale = 1.0
1143
-
1144
- if USE_PEFT_BACKEND:
1145
- # weight the lora layers by setting `lora_scale` for each PEFT layer
1146
- scale_lora_layers(self, lora_scale)
1147
- else:
1148
- if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
1149
- logger.warning(
1150
- "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
1151
- )
1152
-
1153
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
1154
- p, p_t = self.config.patch_size, self.config.patch_size_t
1155
- post_patch_num_frames = num_frames // p_t
1156
- post_patch_height = height // p
1157
- post_patch_width = width // p
1158
- first_frame_num_tokens = 1 * post_patch_height * post_patch_width
1159
- # 1. RoPE
1160
- image_rotary_emb = self.rope(hidden_states, timestep=timestep)
1161
- # 2. Conditional embeddings
1162
- temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)
1163
- hidden_states = self.x_embedder(hidden_states)
1164
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1165
-
1166
- # First frame conditioning: Image embeddings from vision encoder
1167
- if image_embeds is not None:
1168
- # image_embeds: [B, N, D_img] -> [B, N, D]
1169
- image_embeds = self.image_embedder(image_embeds)
1170
- encoder_hidden_states = torch.cat([image_embeds, encoder_hidden_states], dim=1)
1171
- # Extend attention mask for image tokens
1172
- if encoder_attention_mask is not None:
1173
- image_mask = torch.ones(
1174
- image_embeds.shape[0],
1175
- image_embeds.shape[1],
1176
- device=encoder_attention_mask.device,
1177
- dtype=encoder_attention_mask.dtype,
1178
- )
1179
- encoder_attention_mask = torch.cat([image_mask, encoder_attention_mask], dim=1)
1180
-
1181
- # image_embed_seq_len: used by cross-attention blocks to slice text from encoder_hidden_states
1182
- image_embed_seq_len = image_embeds.shape[1] if image_embeds is not None else 0
1183
-
1184
- decoder_hidden_states = hidden_states.clone()
1185
-
1186
- if encoder_attention_mask is not None:
1187
- attention_mask = self._create_attention_mask(
1188
- hidden_states=hidden_states,
1189
- encoder_attention_mask=encoder_attention_mask,
1190
- )
1191
- else:
1192
- attention_mask = None
1193
-
1194
- # TREAD state initialization: manage token reduction manually to support activation checkpointing
1195
- tread_active = False
1196
- current_route = None
1197
- ids_keep = None
1198
- x_full = None
1199
- orig_mask = attention_mask
1200
- orig_rope = image_rotary_emb
1201
- latent_len = hidden_states.shape[1]
1202
-
1203
- # 4. Dual stream transformer blocks (Encoder)
1204
- for i, block in enumerate(self.transformer_blocks):
1205
- # Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
1206
- if is_tread_start(tread_mixin, tread_active, i):
1207
- tread_active = True
1208
- current_route = tread_mixin._tread_route
1209
- # Reduce sequence length at the start of a TREAD route
1210
- ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
1211
- x_full = hidden_states.contiguous()
1212
- hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
1213
- attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
1214
- image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
1215
-
1216
- hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
1217
- block,
1218
- hidden_states,
1219
- encoder_hidden_states,
1220
- temb,
1221
- attention_mask,
1222
- image_rotary_emb,
1223
- token_replace_emb,
1224
- first_frame_num_tokens,
1225
- image_embed_seq_len,
1226
- encoder_attention_mask,
1227
- )
1228
-
1229
- if is_tread_end(tread_mixin, tread_active, i):
1230
- # Restore full sequence length at the end of a TREAD route
1231
- hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
1232
- tread_active = False
1233
- current_route = None
1234
- ids_keep = None
1235
- x_full = None
1236
- attention_mask = orig_mask
1237
- image_rotary_emb = orig_rope
1238
-
1239
- # We need to unwrap the blocks because CheckpointWrapper does not support len(),
1240
- # which is required for slicing the blocks into encoder and decoder parts.
1241
- single_transformer_blocks = self.single_transformer_blocks
1242
-
1243
- # 5. Single stream transformer blocks (Encoder)
1244
- num_dual = len(self.transformer_blocks)
1245
- for i, block in enumerate(
1246
- single_transformer_blocks[: len(single_transformer_blocks) - self.num_decoder_layers]
1247
- ):
1248
- # Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
1249
- abs_i = num_dual + i
1250
- if is_tread_start(tread_mixin, tread_active, abs_i):
1251
- tread_active = True
1252
- current_route = tread_mixin._tread_route
1253
- # Reduce sequence length at the start of a TREAD route
1254
- ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
1255
- x_full = hidden_states.contiguous()
1256
- hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
1257
- attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
1258
- image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
1259
-
1260
- hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
1261
- block,
1262
- hidden_states,
1263
- encoder_hidden_states,
1264
- temb,
1265
- attention_mask,
1266
- image_rotary_emb,
1267
- token_replace_emb,
1268
- first_frame_num_tokens,
1269
- image_embed_seq_len,
1270
- encoder_attention_mask,
1271
- )
1272
-
1273
- if is_tread_end(tread_mixin, tread_active, abs_i):
1274
- # Restore full sequence length at the end of a TREAD route
1275
- hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
1276
- tread_active = False
1277
- current_route = None
1278
- ids_keep = None
1279
- x_full = None
1280
- attention_mask = orig_mask
1281
- image_rotary_emb = orig_rope
1282
-
1283
- # 6. Single stream transformer blocks (Decoder)
1284
- if self.num_decoder_layers > 0:
1285
- encoder_hidden_states = hidden_states
1286
- attention_mask = None
1287
-
1288
- num_single = len(single_transformer_blocks)
1289
-
1290
- for i, block in enumerate(single_transformer_blocks[-self.num_decoder_layers :]):
1291
- abs_i = num_dual + (num_single - self.num_decoder_layers) + i
1292
- if is_tread_start(tread_mixin, tread_active, abs_i):
1293
- tread_active = True
1294
- current_route = tread_mixin._tread_route
1295
- # Reduce sequence length at the start of a TREAD route
1296
- ids_keep = tread_mixin.keep_indices(decoder_hidden_states, current_route["sel"]).to(
1297
- decoder_hidden_states.device
1298
- )
1299
- x_full = encoder_hidden_states.contiguous()
1300
- x_t_full = decoder_hidden_states.contiguous()
1301
- decoder_hidden_states = tread_mixin.gather_tokens(decoder_hidden_states, ids_keep)
1302
- encoder_hidden_states = tread_mixin.gather_tokens(encoder_hidden_states, ids_keep)
1303
- attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
1304
- image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
1305
-
1306
- decoder_hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
1307
- block,
1308
- decoder_hidden_states,
1309
- encoder_hidden_states,
1310
- temb,
1311
- attention_mask,
1312
- image_rotary_emb,
1313
- token_replace_emb,
1314
- first_frame_num_tokens,
1315
- )
1316
-
1317
- if is_tread_end(tread_mixin, tread_active, abs_i):
1318
- # Restore full sequence length at the end of a TREAD route
1319
- decoder_hidden_states = tread_mixin.scatter_tokens(decoder_hidden_states, ids_keep, x_t_full)
1320
- encoder_hidden_states = tread_mixin.scatter_tokens(encoder_hidden_states, ids_keep, x_full)
1321
- tread_active = False
1322
- current_route = None
1323
- ids_keep = None
1324
- x_full = None
1325
- x_t_full = None
1326
- attention_mask = orig_mask
1327
- image_rotary_emb = orig_rope
1328
-
1329
- hidden_states = decoder_hidden_states
1330
-
1331
- # 7. Output projection
1332
- hidden_states = self.norm_out(hidden_states, temb)
1333
- hidden_states = self.proj_out(hidden_states)
1334
-
1335
- hidden_states = hidden_states.reshape(
1336
- batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
1337
- )
1338
- hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
1339
- hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
1340
-
1341
- if USE_PEFT_BACKEND:
1342
- # remove `lora_scale` from each PEFT layer
1343
- unscale_lora_layers(self, lora_scale)
1344
-
1345
- if not return_dict:
1346
- return (hidden_states,)
1347
-
1348
- return Transformer2DModelOutput(
1349
- sample=hidden_states,
1350
- )