Instructions to use Motif-Technologies/Motif-Video-2B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Motif-Technologies/Motif-Video-2B with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("Motif-Technologies/Motif-Video-2B", dtype=torch.bfloat16, device_map="cuda") prompt = "A vibrant blue jay perches gracefully on a slender branch, its feathers shimmering in the soft morning light. The bird's keen eyes scan the surroundings, capturing the essence of the tranquil forest. It flutters its wings briefly, showcasing the intricate patterns of blue, white, and black on its plumage. The background reveals a lush canopy of green leaves, with rays of sunlight filtering through, creating a dappled effect on the forest floor. The blue jay then tilts its head, emitting a melodious call that echoes through the serene woodland, adding a touch of magic to the peaceful scene." image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
diffusers integration
#23
by kencwt - opened
- README.md +1 -3
- inference.py +0 -210
- model_index.json +1 -1
- transformer/config.json +0 -2
- transformer/transformer_motif_video.py +0 -1350
README.md
CHANGED
|
@@ -185,7 +185,6 @@ guider = AdaptiveProjectedGuidance(
|
|
| 185 |
|
| 186 |
pipe = MotifVideoPipeline.from_pretrained(
|
| 187 |
"Motif-Technologies/Motif-Video-2B",
|
| 188 |
-
revision="diffusers-integration",
|
| 189 |
torch_dtype=torch.bfloat16,
|
| 190 |
guider=guider,
|
| 191 |
)
|
|
@@ -234,9 +233,8 @@ guider = AdaptiveProjectedGuidance(
|
|
| 234 |
normalization_dims="spatial",
|
| 235 |
)
|
| 236 |
|
| 237 |
-
pipe =
|
| 238 |
"Motif-Technologies/Motif-Video-2B",
|
| 239 |
-
revision="diffusers-integration",
|
| 240 |
torch_dtype=torch.bfloat16,
|
| 241 |
guider=guider,
|
| 242 |
)
|
|
|
|
| 185 |
|
| 186 |
pipe = MotifVideoPipeline.from_pretrained(
|
| 187 |
"Motif-Technologies/Motif-Video-2B",
|
|
|
|
| 188 |
torch_dtype=torch.bfloat16,
|
| 189 |
guider=guider,
|
| 190 |
)
|
|
|
|
| 233 |
normalization_dims="spatial",
|
| 234 |
)
|
| 235 |
|
| 236 |
+
pipe = MotifVideoImage2VideoPipeline.from_pretrained(
|
| 237 |
"Motif-Technologies/Motif-Video-2B",
|
|
|
|
| 238 |
torch_dtype=torch.bfloat16,
|
| 239 |
guider=guider,
|
| 240 |
)
|
inference.py
DELETED
|
@@ -1,210 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""Motif-Video 2B — Text-to-Video inference.
|
| 3 |
-
|
| 4 |
-
GPU requirements: ~24GB VRAM for 720p (1280x736, 121 frames).
|
| 5 |
-
Requires: torch, diffusers (with MotifVideoPipeline), transformers>=5.5.4,
|
| 6 |
-
accelerate, ftfy, einops, sentencepiece, regex
|
| 7 |
-
|
| 8 |
-
Uses Adaptive Projected Guidance (APG) and DPMSolver++ scheduler by default.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
import argparse
|
| 12 |
-
|
| 13 |
-
import torch
|
| 14 |
-
from diffusers import (
|
| 15 |
-
AdaptiveProjectedGuidance,
|
| 16 |
-
DPMSolverMultistepScheduler,
|
| 17 |
-
MotifVideoPipeline,
|
| 18 |
-
)
|
| 19 |
-
from diffusers.utils import export_to_video
|
| 20 |
-
|
| 21 |
-
_DEFAULT_NEGATIVE_PROMPT = (
|
| 22 |
-
"text overlay, graphic overlay, watermark, logo, subtitles, timestamp, "
|
| 23 |
-
"broadcast graphics, UI elements, random letters, frozen pose, rigid, "
|
| 24 |
-
"static expression, jerky motion, mechanical motion, discontinuous motion, "
|
| 25 |
-
"flat framing, depthless, dull lighting, monotone, crushed shadows, "
|
| 26 |
-
"blown-out highlights, shifting background, fading background, poor continuity, "
|
| 27 |
-
"identity drift, deformation, flickering, ghosting, smearing, duplication, "
|
| 28 |
-
"mutated proportions, inconsistent clothing, flat colors, desaturated, "
|
| 29 |
-
"tonally compressed, poor background separation, exposure shift, "
|
| 30 |
-
"uneven brightness, color balance shift"
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def parse_args():
|
| 35 |
-
parser = argparse.ArgumentParser(description="Motif-Video 2B Inference (T2V)")
|
| 36 |
-
parser.add_argument(
|
| 37 |
-
"--model-path",
|
| 38 |
-
type=str,
|
| 39 |
-
default="Motif-Technologies/Motif-Video-2B",
|
| 40 |
-
help="HuggingFace model ID or local checkpoint path",
|
| 41 |
-
)
|
| 42 |
-
parser.add_argument(
|
| 43 |
-
"--prompt",
|
| 44 |
-
type=str,
|
| 45 |
-
default="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
|
| 46 |
-
help="Text prompt for video generation",
|
| 47 |
-
)
|
| 48 |
-
parser.add_argument(
|
| 49 |
-
"--negative-prompt",
|
| 50 |
-
type=str,
|
| 51 |
-
default=_DEFAULT_NEGATIVE_PROMPT,
|
| 52 |
-
help="Negative prompt",
|
| 53 |
-
)
|
| 54 |
-
parser.add_argument("--output", type=str, default="output.mp4", help="Output video file path")
|
| 55 |
-
parser.add_argument("--num-frames", type=int, default=121, help="Number of frames to generate (121 = ~5s at 24fps)")
|
| 56 |
-
parser.add_argument("--height", type=int, default=736, help="Video height in pixels")
|
| 57 |
-
parser.add_argument("--width", type=int, default=1280, help="Video width in pixels")
|
| 58 |
-
parser.add_argument("--guidance-scale", type=float, default=8.0, help="Classifier-free guidance scale")
|
| 59 |
-
parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of denoising steps")
|
| 60 |
-
parser.add_argument("--fps", type=int, default=24, help="Output video frame rate")
|
| 61 |
-
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
|
| 62 |
-
parser.add_argument(
|
| 63 |
-
"--dtype",
|
| 64 |
-
type=str,
|
| 65 |
-
default="bfloat16",
|
| 66 |
-
choices=["float16", "bfloat16", "float32"],
|
| 67 |
-
help="Model dtype",
|
| 68 |
-
)
|
| 69 |
-
parser.add_argument(
|
| 70 |
-
"--use-sage-attention",
|
| 71 |
-
action="store_true",
|
| 72 |
-
help="Enable SageAttention for ~2x faster attention (requires: pip install sageattention>=2.1.1 from GitHub source)",
|
| 73 |
-
)
|
| 74 |
-
return parser.parse_args()
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def _enable_sage_attention(transformer):
|
| 78 |
-
"""Patch transformer attention to use SageAttention.
|
| 79 |
-
|
| 80 |
-
Only patches _compute_attention (self-attention path). Cross-attention
|
| 81 |
-
uses _handle_cross_attention_mode which calls F.sdpa directly and is
|
| 82 |
-
unaffected by this patch.
|
| 83 |
-
|
| 84 |
-
Mask handling follows motif-models dispatch_optimized_attention pattern:
|
| 85 |
-
- mask=None: sage directly
|
| 86 |
-
- mask with uniform active length: slice active region -> sage -> pad back
|
| 87 |
-
- mask with non-uniform active length: SDPA fallback
|
| 88 |
-
"""
|
| 89 |
-
from sageattention import sageattn
|
| 90 |
-
from diffusers.models.transformers.transformer_motif_video import MotifVideoAttnProcessor2_0
|
| 91 |
-
|
| 92 |
-
_orig_compute = MotifVideoAttnProcessor2_0._compute_attention
|
| 93 |
-
|
| 94 |
-
def _sage_compute(self, query, key, value, attention_mask):
|
| 95 |
-
if attention_mask is None:
|
| 96 |
-
out = sageattn(
|
| 97 |
-
query.contiguous(), key.contiguous(), value.contiguous(),
|
| 98 |
-
tensor_layout="HND", is_causal=False,
|
| 99 |
-
)
|
| 100 |
-
out = out.transpose(1, 2).flatten(2, 3).to(query.dtype)
|
| 101 |
-
return out
|
| 102 |
-
|
| 103 |
-
# Find active token count from mask (shape: [B, 1, 1, S])
|
| 104 |
-
padding_indices = attention_mask.sum(dim=-1).long().flatten()
|
| 105 |
-
common_padding_index = padding_indices[0]
|
| 106 |
-
is_uniform = (padding_indices == common_padding_index).all()
|
| 107 |
-
|
| 108 |
-
if not is_uniform:
|
| 109 |
-
return _orig_compute(self, query, key, value, attention_mask)
|
| 110 |
-
|
| 111 |
-
active_len = common_padding_index.item()
|
| 112 |
-
S = query.shape[2]
|
| 113 |
-
|
| 114 |
-
if active_len == S:
|
| 115 |
-
out = sageattn(
|
| 116 |
-
query.contiguous(), key.contiguous(), value.contiguous(),
|
| 117 |
-
tensor_layout="HND", is_causal=False,
|
| 118 |
-
)
|
| 119 |
-
out = out.transpose(1, 2).flatten(2, 3).to(query.dtype)
|
| 120 |
-
return out
|
| 121 |
-
|
| 122 |
-
# Slice to active region, run sage, pad back
|
| 123 |
-
q_a = query[:, :, :active_len, :].contiguous()
|
| 124 |
-
k_a = key[:, :, :active_len, :].contiguous()
|
| 125 |
-
v_a = value[:, :, :active_len, :].contiguous()
|
| 126 |
-
|
| 127 |
-
out_a = sageattn(q_a, k_a, v_a, tensor_layout="HND", is_causal=False)
|
| 128 |
-
|
| 129 |
-
out = query.new_zeros(query.shape)
|
| 130 |
-
out[:, :, :active_len, :] = out_a
|
| 131 |
-
out = out.transpose(1, 2).flatten(2, 3).to(query.dtype)
|
| 132 |
-
return out
|
| 133 |
-
|
| 134 |
-
MotifVideoAttnProcessor2_0._compute_attention = _sage_compute
|
| 135 |
-
transformer.to(memory_format=torch.channels_last_3d)
|
| 136 |
-
print("[SageAttention] Enabled (patched _compute_attention + channels_last_3d)")
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def main():
|
| 140 |
-
args = parse_args()
|
| 141 |
-
|
| 142 |
-
dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
|
| 143 |
-
torch_dtype = dtype_map[args.dtype]
|
| 144 |
-
|
| 145 |
-
print(f"[T2V] Loading model from: {args.model_path}")
|
| 146 |
-
|
| 147 |
-
guider = AdaptiveProjectedGuidance(
|
| 148 |
-
guidance_scale=args.guidance_scale,
|
| 149 |
-
adaptive_projected_guidance_rescale=12.0,
|
| 150 |
-
adaptive_projected_guidance_momentum=0.1,
|
| 151 |
-
use_original_formulation=True,
|
| 152 |
-
normalization_dims="spatial",
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
pipe = MotifVideoPipeline.from_pretrained(
|
| 156 |
-
args.model_path,
|
| 157 |
-
torch_dtype=torch_dtype,
|
| 158 |
-
guider=guider,
|
| 159 |
-
)
|
| 160 |
-
|
| 161 |
-
# Replace scheduler with DPMSolver++ for faster convergence and better quality.
|
| 162 |
-
# Subclass ignores pipeline-supplied sigmas (PR branch always passes them)
|
| 163 |
-
# and uses its own flow-matching sigma schedule instead.
|
| 164 |
-
class _FlowDPMSolver(DPMSolverMultistepScheduler):
|
| 165 |
-
def set_timesteps(self, num_inference_steps=None, device=None,
|
| 166 |
-
sigmas=None, mu=None, timesteps=None):
|
| 167 |
-
if sigmas is not None and num_inference_steps is None:
|
| 168 |
-
num_inference_steps = len(sigmas)
|
| 169 |
-
super().set_timesteps(
|
| 170 |
-
num_inference_steps=num_inference_steps,
|
| 171 |
-
device=device, timesteps=timesteps,
|
| 172 |
-
)
|
| 173 |
-
|
| 174 |
-
pipe.scheduler = _FlowDPMSolver(
|
| 175 |
-
num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
|
| 176 |
-
algorithm_type="dpmsolver++",
|
| 177 |
-
solver_order=2,
|
| 178 |
-
prediction_type="flow_prediction",
|
| 179 |
-
use_flow_sigmas=True,
|
| 180 |
-
flow_shift=15.0,
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
# Offload model components to CPU between uses to reduce peak VRAM
|
| 184 |
-
pipe.enable_model_cpu_offload()
|
| 185 |
-
|
| 186 |
-
if args.use_sage_attention:
|
| 187 |
-
_enable_sage_attention(pipe.transformer)
|
| 188 |
-
|
| 189 |
-
generator = torch.Generator(device="cuda").manual_seed(args.seed)
|
| 190 |
-
|
| 191 |
-
print(f"Generating video: {args.width}x{args.height}, {args.num_frames} frames, {args.num_inference_steps} steps")
|
| 192 |
-
output = pipe(
|
| 193 |
-
prompt=args.prompt,
|
| 194 |
-
negative_prompt=args.negative_prompt,
|
| 195 |
-
height=args.height,
|
| 196 |
-
width=args.width,
|
| 197 |
-
num_frames=args.num_frames,
|
| 198 |
-
num_inference_steps=args.num_inference_steps,
|
| 199 |
-
frame_rate=args.fps,
|
| 200 |
-
use_linear_quadratic_schedule=False,
|
| 201 |
-
generator=generator,
|
| 202 |
-
)
|
| 203 |
-
|
| 204 |
-
video_frames = output.frames[0]
|
| 205 |
-
export_to_video(video_frames, args.output, fps=args.fps)
|
| 206 |
-
print(f"Video saved to: {args.output}")
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
if __name__ == "__main__":
|
| 210 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_index.json
CHANGED
|
@@ -14,7 +14,7 @@
|
|
| 14 |
"GemmaTokenizer"
|
| 15 |
],
|
| 16 |
"transformer": [
|
| 17 |
-
"
|
| 18 |
"MotifVideoTransformer3DModel"
|
| 19 |
],
|
| 20 |
"vae": [
|
|
|
|
| 14 |
"GemmaTokenizer"
|
| 15 |
],
|
| 16 |
"transformer": [
|
| 17 |
+
"diffusers",
|
| 18 |
"MotifVideoTransformer3DModel"
|
| 19 |
],
|
| 20 |
"vae": [
|
transformer/config.json
CHANGED
|
@@ -3,7 +3,6 @@
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
| 4 |
"_library": "diffusers",
|
| 5 |
"attention_head_dim": 128,
|
| 6 |
-
"base_latent_size": null,
|
| 7 |
"image_embed_dim": 1152,
|
| 8 |
"in_channels": 33,
|
| 9 |
"mlp_ratio": 4.0,
|
|
@@ -15,7 +14,6 @@
|
|
| 15 |
"out_channels": 16,
|
| 16 |
"patch_size": 2,
|
| 17 |
"patch_size_t": 1,
|
| 18 |
-
"pooled_projection_dim": null,
|
| 19 |
"qk_norm": "rms_norm",
|
| 20 |
"rope_axes_dim": [
|
| 21 |
16,
|
|
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
| 4 |
"_library": "diffusers",
|
| 5 |
"attention_head_dim": 128,
|
|
|
|
| 6 |
"image_embed_dim": 1152,
|
| 7 |
"in_channels": 33,
|
| 8 |
"mlp_ratio": 4.0,
|
|
|
|
| 14 |
"out_channels": 16,
|
| 15 |
"patch_size": 2,
|
| 16 |
"patch_size_t": 1,
|
|
|
|
| 17 |
"qk_norm": "rms_norm",
|
| 18 |
"rope_axes_dim": [
|
| 19 |
16,
|
transformer/transformer_motif_video.py
DELETED
|
@@ -1,1350 +0,0 @@
|
|
| 1 |
-
# Copyright 2026 Motif Technologies. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import math
|
| 16 |
-
from functools import lru_cache
|
| 17 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 18 |
-
|
| 19 |
-
import numpy as np
|
| 20 |
-
import torch
|
| 21 |
-
import torch.nn as nn
|
| 22 |
-
import torch.nn.functional as F
|
| 23 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
-
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
|
| 25 |
-
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 26 |
-
from diffusers.models.attention import FeedForward
|
| 27 |
-
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
| 28 |
-
from diffusers.models.cache_utils import CacheMixin
|
| 29 |
-
from diffusers.models.embeddings import (
|
| 30 |
-
PixArtAlphaTextProjection,
|
| 31 |
-
TimestepEmbedding,
|
| 32 |
-
Timesteps,
|
| 33 |
-
)
|
| 34 |
-
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 35 |
-
from diffusers.models.modeling_utils import ModelMixin
|
| 36 |
-
from diffusers.models.normalization import (
|
| 37 |
-
AdaLayerNormContinuous,
|
| 38 |
-
AdaLayerNormZero,
|
| 39 |
-
AdaLayerNormZeroSingle,
|
| 40 |
-
)
|
| 41 |
-
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 42 |
-
|
| 43 |
-
# Stub functions for TREAD (Token REduction with Approximated Distillation).
|
| 44 |
-
# These stubs ensure TREAD code paths are never activated during inference
|
| 45 |
-
# without requiring the motif_core package.
|
| 46 |
-
def is_tread_start(block_idx, start, end): return False
|
| 47 |
-
def is_tread_end(block_idx, start, end): return False
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 51 |
-
|
| 52 |
-
NUM_TRAIN_TIMESTEPS = 1000
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def apply_rotary_emb(
|
| 56 |
-
x: torch.Tensor,
|
| 57 |
-
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
|
| 58 |
-
use_real: bool = True,
|
| 59 |
-
use_real_unbind_dim: int = -1,
|
| 60 |
-
) -> torch.Tensor:
|
| 61 |
-
"""
|
| 62 |
-
Apply rotary positional embeddings (RoPE) to input tensors.
|
| 63 |
-
|
| 64 |
-
This implementation supports both standard 2D RoPE tensors [L, Dh] and batched 4D RoPE
|
| 65 |
-
tensors [B, 1, L, Dh] for compatibility with TREAD's token-dropping mechanism where
|
| 66 |
-
different batches may have different token subsets.
|
| 67 |
-
|
| 68 |
-
Args:
|
| 69 |
-
x: Input tensor of shape [B, H, L, Dh].
|
| 70 |
-
freqs_cis: Tuple of (cos, sin) tensors. Supports shapes [L, Dh] or [B, 1, L, Dh].
|
| 71 |
-
use_real: Whether to use real-valued RoPE implementation.
|
| 72 |
-
use_real_unbind_dim: Dimension to unbind when using real-valued RoPE (-1 or -2).
|
| 73 |
-
|
| 74 |
-
Returns:
|
| 75 |
-
Tensor with rotary embeddings applied, same shape as input x.
|
| 76 |
-
"""
|
| 77 |
-
if use_real:
|
| 78 |
-
cos, sin = freqs_cis
|
| 79 |
-
if cos.dim() == 2: # [L, Dh] → [1, 1, L, Dh]
|
| 80 |
-
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 81 |
-
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 82 |
-
if cos.dim() != 4 or sin.dim() != 4:
|
| 83 |
-
raise RuntimeError(f"RoPE must be 2D or 4D, got cos={cos.dim()}D, sin={sin.dim()}D")
|
| 84 |
-
|
| 85 |
-
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 86 |
-
|
| 87 |
-
if cos.size(-2) != x.size(-2) or cos.size(-1) != x.size(-1):
|
| 88 |
-
raise RuntimeError(
|
| 89 |
-
f"RoPE shape mismatch: rope[-2:]=({cos.size(-2)},{cos.size(-1)}) vs x[-2:]=({x.size(-2)},{x.size(-1)})"
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
if use_real_unbind_dim == -1:
|
| 93 |
-
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
| 94 |
-
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 95 |
-
elif use_real_unbind_dim == -2:
|
| 96 |
-
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
|
| 97 |
-
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 98 |
-
else:
|
| 99 |
-
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 100 |
-
|
| 101 |
-
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 102 |
-
return out
|
| 103 |
-
else:
|
| 104 |
-
x_rot = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 105 |
-
freqs = freqs_cis.unsqueeze(2)
|
| 106 |
-
x_out = torch.view_as_real(x_rot * freqs).flatten(3)
|
| 107 |
-
return x_out.type_as(x)
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
class MotifVideoAttnProcessor2_0:
|
| 111 |
-
def __init__(self):
|
| 112 |
-
if not hasattr(F, "scaled_dot_product_attention"):
|
| 113 |
-
raise ImportError(
|
| 114 |
-
"MotifVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
-
def __call__(
|
| 118 |
-
self,
|
| 119 |
-
attn: Attention,
|
| 120 |
-
hidden_states: torch.Tensor,
|
| 121 |
-
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 122 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 123 |
-
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 124 |
-
query_input: Optional[torch.Tensor] = None,
|
| 125 |
-
key_input: Optional[torch.Tensor] = None,
|
| 126 |
-
value_input: Optional[torch.Tensor] = None,
|
| 127 |
-
) -> torch.Tensor:
|
| 128 |
-
# Cross-attention mode: query already projected externally (cross_attn_query_proj + norm),
|
| 129 |
-
# skip to_q and only apply reshape + norm_q + RoPE. K/V use to_k/to_v as normal.
|
| 130 |
-
if query_input is not None:
|
| 131 |
-
query = query_input.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 132 |
-
key = attn.to_k(key_input)
|
| 133 |
-
value = attn.to_v(value_input)
|
| 134 |
-
|
| 135 |
-
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 136 |
-
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 137 |
-
|
| 138 |
-
if attn.norm_q is not None:
|
| 139 |
-
query = attn.norm_q(query)
|
| 140 |
-
if attn.norm_k is not None:
|
| 141 |
-
key = attn.norm_k(key)
|
| 142 |
-
|
| 143 |
-
if image_rotary_emb is not None:
|
| 144 |
-
query = apply_rotary_emb(query, image_rotary_emb)
|
| 145 |
-
|
| 146 |
-
hidden_states = F.scaled_dot_product_attention(
|
| 147 |
-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 148 |
-
)
|
| 149 |
-
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
| 150 |
-
hidden_states = hidden_states.to(query.dtype)
|
| 151 |
-
return hidden_states, None
|
| 152 |
-
|
| 153 |
-
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 154 |
-
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 155 |
-
|
| 156 |
-
# 1. QKV projections
|
| 157 |
-
query = attn.to_q(hidden_states)
|
| 158 |
-
key = attn.to_k(hidden_states)
|
| 159 |
-
value = attn.to_v(hidden_states)
|
| 160 |
-
|
| 161 |
-
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 162 |
-
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 163 |
-
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 164 |
-
|
| 165 |
-
# 2. QK normalization
|
| 166 |
-
if attn.norm_q is not None:
|
| 167 |
-
query = attn.norm_q(query)
|
| 168 |
-
if attn.norm_k is not None:
|
| 169 |
-
key = attn.norm_k(key)
|
| 170 |
-
|
| 171 |
-
# 3. Rotational positional embeddings applied to latent stream
|
| 172 |
-
if image_rotary_emb is not None:
|
| 173 |
-
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 174 |
-
query = torch.cat(
|
| 175 |
-
[
|
| 176 |
-
apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 177 |
-
query[:, :, -encoder_hidden_states.shape[1] :],
|
| 178 |
-
],
|
| 179 |
-
dim=2,
|
| 180 |
-
)
|
| 181 |
-
key = torch.cat(
|
| 182 |
-
[
|
| 183 |
-
apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 184 |
-
key[:, :, -encoder_hidden_states.shape[1] :],
|
| 185 |
-
],
|
| 186 |
-
dim=2,
|
| 187 |
-
)
|
| 188 |
-
else:
|
| 189 |
-
query = apply_rotary_emb(query, image_rotary_emb)
|
| 190 |
-
key = apply_rotary_emb(key, image_rotary_emb)
|
| 191 |
-
|
| 192 |
-
# 4. Encoder condition QKV projection and normalization
|
| 193 |
-
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
| 194 |
-
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 195 |
-
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 196 |
-
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 197 |
-
|
| 198 |
-
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 199 |
-
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 200 |
-
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 201 |
-
|
| 202 |
-
if attn.norm_added_q is not None:
|
| 203 |
-
encoder_query = attn.norm_added_q(encoder_query)
|
| 204 |
-
if attn.norm_added_k is not None:
|
| 205 |
-
encoder_key = attn.norm_added_k(encoder_key)
|
| 206 |
-
|
| 207 |
-
query = torch.cat([query, encoder_query], dim=2)
|
| 208 |
-
key = torch.cat([key, encoder_key], dim=2)
|
| 209 |
-
value = torch.cat([value, encoder_value], dim=2)
|
| 210 |
-
|
| 211 |
-
# 5. Attention
|
| 212 |
-
hidden_states = F.scaled_dot_product_attention(
|
| 213 |
-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 214 |
-
)
|
| 215 |
-
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
| 216 |
-
hidden_states = hidden_states.to(query.dtype)
|
| 217 |
-
|
| 218 |
-
# 6. Output projection
|
| 219 |
-
if encoder_hidden_states is not None:
|
| 220 |
-
hidden_states, encoder_hidden_states = (
|
| 221 |
-
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
| 222 |
-
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
| 223 |
-
)
|
| 224 |
-
|
| 225 |
-
if getattr(attn, "to_out", None) is not None:
|
| 226 |
-
hidden_states = attn.to_out[0](hidden_states)
|
| 227 |
-
hidden_states = attn.to_out[1](hidden_states)
|
| 228 |
-
|
| 229 |
-
if getattr(attn, "to_add_out", None) is not None:
|
| 230 |
-
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 231 |
-
|
| 232 |
-
return hidden_states, encoder_hidden_states
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
class MotifVideoPatchEmbed(nn.Module):
|
| 236 |
-
def __init__(
|
| 237 |
-
self,
|
| 238 |
-
patch_size: Union[int, Tuple[int, int, int]] = 16,
|
| 239 |
-
in_chans: int = 3,
|
| 240 |
-
embed_dim: int = 768,
|
| 241 |
-
) -> None:
|
| 242 |
-
super().__init__()
|
| 243 |
-
|
| 244 |
-
patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
|
| 245 |
-
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 246 |
-
|
| 247 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 248 |
-
hidden_states = self.proj(hidden_states)
|
| 249 |
-
hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
| 250 |
-
return hidden_states
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
class MotifVideoAdaNorm(nn.Module):
|
| 254 |
-
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
| 255 |
-
super().__init__()
|
| 256 |
-
|
| 257 |
-
out_features = out_features or 2 * in_features
|
| 258 |
-
self.linear = nn.Linear(in_features, out_features)
|
| 259 |
-
self.nonlinearity = nn.SiLU()
|
| 260 |
-
|
| 261 |
-
def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 262 |
-
temb = self.linear(self.nonlinearity(temb))
|
| 263 |
-
gate_msa, gate_mlp = temb.chunk(2, dim=1)
|
| 264 |
-
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
| 265 |
-
return gate_msa, gate_mlp
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
class MotifVideoConditionEmbedding(nn.Module):
|
| 269 |
-
def __init__(
|
| 270 |
-
self,
|
| 271 |
-
embedding_dim: int,
|
| 272 |
-
pooled_projection_dim: int | None,
|
| 273 |
-
):
|
| 274 |
-
super().__init__()
|
| 275 |
-
|
| 276 |
-
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
| 277 |
-
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
| 278 |
-
|
| 279 |
-
if isinstance(pooled_projection_dim, int):
|
| 280 |
-
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
| 281 |
-
|
| 282 |
-
def forward(
|
| 283 |
-
self,
|
| 284 |
-
timestep: torch.Tensor,
|
| 285 |
-
pooled_projection: torch.Tensor | None = None,
|
| 286 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 287 |
-
timesteps_proj = self.time_proj(timestep)
|
| 288 |
-
timestep_embedder_dtype = next(self.timestep_embedder.parameters()).dtype
|
| 289 |
-
conditioning = self.timestep_embedder(timesteps_proj.to(timestep_embedder_dtype)) # (N, D)
|
| 290 |
-
if pooled_projection is not None:
|
| 291 |
-
conditioning = conditioning + self.text_embedder(pooled_projection)
|
| 292 |
-
|
| 293 |
-
token_replace_emb = None
|
| 294 |
-
|
| 295 |
-
return conditioning, token_replace_emb
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L485-L486
|
| 299 |
-
def find_correction_factor(num_rotations, dim, base, max_position_embeddings):
|
| 300 |
-
dtype = num_rotations.dtype if isinstance(num_rotations, torch.Tensor) else torch.float32
|
| 301 |
-
max_pos_tensor = torch.as_tensor(max_position_embeddings, dtype=dtype)
|
| 302 |
-
return (dim * torch.log(max_pos_tensor / (num_rotations * 2 * math.pi))) / (
|
| 303 |
-
2 * math.log(base)
|
| 304 |
-
) # Inverse dim formula to find number of rotations
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L489-L495
|
| 308 |
-
def find_correction_range(low_ratio, high_ratio, dim, base, ori_max_pe_len):
|
| 309 |
-
"""
|
| 310 |
-
Find the correction range for NTK-by-parts interpolation.
|
| 311 |
-
"""
|
| 312 |
-
low = torch.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len))
|
| 313 |
-
high = torch.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len))
|
| 314 |
-
low = torch.clamp(low, min=0)
|
| 315 |
-
high = torch.clamp(high, max=dim - 1)
|
| 316 |
-
return low, high # Clamp values just in case
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L498-L504
|
| 320 |
-
def linear_ramp_mask(min_val, max_val, num_dim):
|
| 321 |
-
if isinstance(min_val, torch.Tensor):
|
| 322 |
-
if (min_val == max_val).all():
|
| 323 |
-
max_val = max_val + 0.001
|
| 324 |
-
elif min_val == max_val:
|
| 325 |
-
max_val += 0.001
|
| 326 |
-
|
| 327 |
-
linear_func = (torch.arange(num_dim, dtype=torch.float32) - min_val) / (max_val - min_val)
|
| 328 |
-
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 329 |
-
return ramp_func
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L507-L511
|
| 333 |
-
def find_newbase_ntk(dim, base, scale):
|
| 334 |
-
"""
|
| 335 |
-
Calculate the new base for NTK-aware scaling.
|
| 336 |
-
"""
|
| 337 |
-
# Avoid division by zero when dim == 2 (or invalid smaller values).
|
| 338 |
-
# In these degenerate cases, fall back to the original base (no NTK adjustment).
|
| 339 |
-
if dim <= 2:
|
| 340 |
-
return base
|
| 341 |
-
return base * (scale ** (dim / (dim - 2)))
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L514-L652
|
| 345 |
-
def get_1d_rotary_pos_embed(
|
| 346 |
-
dim: int,
|
| 347 |
-
pos: Union[np.ndarray, int],
|
| 348 |
-
theta: float = 10000.0,
|
| 349 |
-
use_real=False,
|
| 350 |
-
linear_factor=1.0,
|
| 351 |
-
ntk_factor=1.0,
|
| 352 |
-
repeat_interleave_real=True,
|
| 353 |
-
freqs_dtype=torch.float32,
|
| 354 |
-
yarn=False,
|
| 355 |
-
max_pe_len=None,
|
| 356 |
-
ori_max_pe_len=64,
|
| 357 |
-
dype=False,
|
| 358 |
-
current_timestep=1.0,
|
| 359 |
-
):
|
| 360 |
-
"""
|
| 361 |
-
Precompute the frequency tensor for complex exponentials with RoPE.
|
| 362 |
-
Supports YARN interpolation for vision transformers.
|
| 363 |
-
|
| 364 |
-
Args:
|
| 365 |
-
dim (`int`):
|
| 366 |
-
Dimension of the frequency tensor.
|
| 367 |
-
pos (`np.ndarray` or `int`):
|
| 368 |
-
Position indices for the frequency tensor. [S] or scalar.
|
| 369 |
-
theta (`float`, *optional*, defaults to 10000.0):
|
| 370 |
-
Scaling factor for frequency computation.
|
| 371 |
-
use_real (`bool`, *optional*, defaults to False):
|
| 372 |
-
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
| 373 |
-
linear_factor (`float`, *optional*, defaults to 1.0):
|
| 374 |
-
Scaling factor for linear interpolation.
|
| 375 |
-
ntk_factor (`float`, *optional*, defaults to 1.0):
|
| 376 |
-
Scaling factor for NTK-Aware RoPE.
|
| 377 |
-
repeat_interleave_real (`bool`, *optional*, defaults to True):
|
| 378 |
-
If True and use_real, real and imaginary parts are interleaved with themselves to reach dim.
|
| 379 |
-
Otherwise, they are concatenated.
|
| 380 |
-
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
| 381 |
-
Data type of the frequency tensor.
|
| 382 |
-
yarn (`bool`, *optional*, defaults to False):
|
| 383 |
-
If True, use YARN interpolation combining NTK, linear, and base methods.
|
| 384 |
-
max_pe_len (`int`, *optional*):
|
| 385 |
-
Maximum position encoding length (current patches for vision models).
|
| 386 |
-
ori_max_pe_len (`int`, *optional*, defaults to 64):
|
| 387 |
-
Original maximum position encoding length (base patches for vision models).
|
| 388 |
-
dype (`bool`, *optional*, defaults to False):
|
| 389 |
-
If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
|
| 390 |
-
current_timestep (`float`, *optional*, defaults to 1.0):
|
| 391 |
-
Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.
|
| 392 |
-
|
| 393 |
-
Returns:
|
| 394 |
-
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
| 395 |
-
If use_real=True, returns tuple of (cos, sin) tensors.
|
| 396 |
-
"""
|
| 397 |
-
assert dim % 2 == 0
|
| 398 |
-
|
| 399 |
-
if isinstance(pos, int):
|
| 400 |
-
pos = torch.arange(pos)
|
| 401 |
-
if isinstance(pos, np.ndarray):
|
| 402 |
-
pos = torch.from_numpy(pos)
|
| 403 |
-
|
| 404 |
-
device = pos.device
|
| 405 |
-
|
| 406 |
-
if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
|
| 407 |
-
if not isinstance(max_pe_len, torch.Tensor):
|
| 408 |
-
max_pe_len = torch.tensor(max_pe_len, dtype=freqs_dtype, device=device)
|
| 409 |
-
|
| 410 |
-
scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
|
| 411 |
-
|
| 412 |
-
beta_0 = 1.25
|
| 413 |
-
beta_1 = 0.75
|
| 414 |
-
gamma_0 = 16
|
| 415 |
-
gamma_1 = 2
|
| 416 |
-
|
| 417 |
-
freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
|
| 418 |
-
|
| 419 |
-
freqs_linear = 1.0 / torch.einsum(
|
| 420 |
-
"..., f -> ... f",
|
| 421 |
-
scale,
|
| 422 |
-
(theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)),
|
| 423 |
-
)
|
| 424 |
-
|
| 425 |
-
new_base = find_newbase_ntk(dim, theta, scale)
|
| 426 |
-
if new_base.dim() > 0:
|
| 427 |
-
new_base = new_base.view(-1, 1)
|
| 428 |
-
freqs_ntk = 1.0 / torch.pow(new_base, (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
|
| 429 |
-
if freqs_ntk.dim() > 1:
|
| 430 |
-
freqs_ntk = freqs_ntk.squeeze()
|
| 431 |
-
|
| 432 |
-
if dype:
|
| 433 |
-
beta_0 = torch.pow(beta_0, 2.0 * torch.pow(current_timestep, 2.0))
|
| 434 |
-
beta_1 = torch.pow(beta_1, 2.0 * torch.pow(current_timestep, 2.0))
|
| 435 |
-
|
| 436 |
-
low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
|
| 437 |
-
high = torch.clamp(high, max=dim // 2)
|
| 438 |
-
|
| 439 |
-
freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
|
| 440 |
-
freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
|
| 441 |
-
|
| 442 |
-
if dype:
|
| 443 |
-
gamma_0 = torch.pow(gamma_0, 2.0 * torch.pow(current_timestep, 2.0))
|
| 444 |
-
gamma_1 = torch.pow(gamma_1, 2.0 * torch.pow(current_timestep, 2.0))
|
| 445 |
-
|
| 446 |
-
low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
|
| 447 |
-
high = torch.clamp(high, max=dim // 2)
|
| 448 |
-
|
| 449 |
-
freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
|
| 450 |
-
freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
|
| 451 |
-
|
| 452 |
-
else:
|
| 453 |
-
theta_ntk = theta * ntk_factor
|
| 454 |
-
freqs = 1.0 / (theta_ntk ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)) / linear_factor
|
| 455 |
-
|
| 456 |
-
freqs = torch.outer(pos, freqs)
|
| 457 |
-
|
| 458 |
-
is_npu = freqs.device.type == "npu"
|
| 459 |
-
if is_npu:
|
| 460 |
-
freqs = freqs.float()
|
| 461 |
-
|
| 462 |
-
if use_real and repeat_interleave_real:
|
| 463 |
-
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
|
| 464 |
-
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
|
| 465 |
-
|
| 466 |
-
if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
|
| 467 |
-
mscale = torch.where(scale <= 1.0, 1.0, 0.1 * torch.log(scale) + 1.0).to(scale)
|
| 468 |
-
freqs_cos = freqs_cos * mscale
|
| 469 |
-
freqs_sin = freqs_sin * mscale
|
| 470 |
-
|
| 471 |
-
return freqs_cos, freqs_sin
|
| 472 |
-
elif use_real:
|
| 473 |
-
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
|
| 474 |
-
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
|
| 475 |
-
return freqs_cos, freqs_sin
|
| 476 |
-
else:
|
| 477 |
-
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 478 |
-
return freqs_cis
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
class MotifVideoRotaryPosEmbed(nn.Module):
|
| 482 |
-
def __init__(
|
| 483 |
-
self,
|
| 484 |
-
patch_size: int,
|
| 485 |
-
patch_size_t: int,
|
| 486 |
-
rope_dim: List[int],
|
| 487 |
-
theta: float = 256.0,
|
| 488 |
-
base_latent_size: int | None = None,
|
| 489 |
-
):
|
| 490 |
-
"""
|
| 491 |
-
Rotary Positional Embedding (RoPE) for video latents.
|
| 492 |
-
|
| 493 |
-
Args:
|
| 494 |
-
patch_size (`int`):
|
| 495 |
-
Spatial patch size (e.g., 2).
|
| 496 |
-
patch_size_t (`int`):
|
| 497 |
-
Temporal patch size (e.g., 1).
|
| 498 |
-
rope_dim (`List[int]`):
|
| 499 |
-
Dimensions for RoPE across [Time, Height, Width] axes.
|
| 500 |
-
theta (`float`, *optional*, defaults to 256.0):
|
| 501 |
-
Base frequency for rotary embeddings.
|
| 502 |
-
base_latent_size (`int`, *optional*):
|
| 503 |
-
The maximum spatial dimension (in latent units) seen during training,
|
| 504 |
-
i.e. `training_resolution / vae_scale_factor_spatial`.
|
| 505 |
-
For example, for 1280x1280 training images and a VAE spatial downscale
|
| 506 |
-
(`vae_scale_factor_spatial`) of 8, this would be 160; for a downscale
|
| 507 |
-
of 16, it would be 80.
|
| 508 |
-
"""
|
| 509 |
-
super().__init__()
|
| 510 |
-
|
| 511 |
-
self.patch_size = patch_size
|
| 512 |
-
self.patch_size_t = patch_size_t
|
| 513 |
-
self.rope_dim = rope_dim
|
| 514 |
-
self.theta = theta
|
| 515 |
-
self.base_latent_size = base_latent_size
|
| 516 |
-
|
| 517 |
-
@lru_cache(maxsize=8)
|
| 518 |
-
def _get_base_patch_grid_size(self, base_latent_size: Optional[int], patch_size: int) -> Optional[int]:
|
| 519 |
-
return base_latent_size // patch_size if base_latent_size else None
|
| 520 |
-
|
| 521 |
-
@lru_cache(maxsize=8)
|
| 522 |
-
def _get_dynamic_interpolation_scale(self, h: int, w: int, base_grid_size: int) -> float:
|
| 523 |
-
return math.sqrt(h * w / (base_grid_size**2))
|
| 524 |
-
|
| 525 |
-
def forward(self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 526 |
-
if self.training:
|
| 527 |
-
assert self.base_latent_size is None, (
|
| 528 |
-
"RoPE interpolation/extrapolation logic should only be enabled for inference. "
|
| 529 |
-
f"During training, base_latent_size must be None, but got {self.base_latent_size!r}."
|
| 530 |
-
)
|
| 531 |
-
|
| 532 |
-
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 533 |
-
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
|
| 534 |
-
|
| 535 |
-
axes_grids = []
|
| 536 |
-
for i in range(3):
|
| 537 |
-
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
|
| 538 |
-
# original implementation creates it on CPU and then moves it to device. This results in numerical
|
| 539 |
-
# differences in layerwise debugging outputs, but visually it is the same.
|
| 540 |
-
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
|
| 541 |
-
axes_grids.append(grid)
|
| 542 |
-
grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
|
| 543 |
-
grid = torch.stack(grid, dim=0) # [3, W, H, T]
|
| 544 |
-
|
| 545 |
-
base_patch_grid_size = self._get_base_patch_grid_size(self.base_latent_size, self.patch_size)
|
| 546 |
-
if base_patch_grid_size is not None:
|
| 547 |
-
if base_patch_grid_size <= 0:
|
| 548 |
-
raise ValueError(f"base_patch_grid_size must be a positive number, got {base_patch_grid_size}.")
|
| 549 |
-
dynamic_interpolation_scale = self._get_dynamic_interpolation_scale(
|
| 550 |
-
rope_sizes[1], rope_sizes[2], base_patch_grid_size
|
| 551 |
-
)
|
| 552 |
-
|
| 553 |
-
normalized_timestep = torch.tensor(1.0)
|
| 554 |
-
if not self.training and timestep is not None:
|
| 555 |
-
normalized_timestep = timestep[0] / NUM_TRAIN_TIMESTEPS
|
| 556 |
-
|
| 557 |
-
freqs = []
|
| 558 |
-
for i in range(3):
|
| 559 |
-
common_kwargs = {
|
| 560 |
-
"dim": self.rope_dim[i],
|
| 561 |
-
"pos": grid[i].reshape(-1),
|
| 562 |
-
"theta": self.theta,
|
| 563 |
-
"use_real": True,
|
| 564 |
-
"freqs_dtype": torch.float64,
|
| 565 |
-
}
|
| 566 |
-
|
| 567 |
-
# Apply scaling only to spatial dimensions (Height and Width, i=1 and i=2)
|
| 568 |
-
if i > 0 and base_patch_grid_size is not None and dynamic_interpolation_scale > 1.0:
|
| 569 |
-
# We project the training base to the current size using the uniform scale factor.
|
| 570 |
-
# max_pe_len tells the RoPE logic the "new" maximum length it's dealing with.
|
| 571 |
-
max_pe_len = torch.tensor(
|
| 572 |
-
base_patch_grid_size * dynamic_interpolation_scale,
|
| 573 |
-
dtype=torch.float64,
|
| 574 |
-
device=hidden_states.device,
|
| 575 |
-
)
|
| 576 |
-
|
| 577 |
-
freq = get_1d_rotary_pos_embed(
|
| 578 |
-
**common_kwargs,
|
| 579 |
-
yarn=True, # Enable Yet Another RoPE extensioN (YARN) for extrapolation
|
| 580 |
-
max_pe_len=max_pe_len,
|
| 581 |
-
ori_max_pe_len=base_patch_grid_size, # The original training scale
|
| 582 |
-
dype=True, # Enable Dynamic Position Encoding (time-aware)
|
| 583 |
-
current_timestep=normalized_timestep,
|
| 584 |
-
)
|
| 585 |
-
else:
|
| 586 |
-
# Time dimension OR within training bounds -> Standard RoPE
|
| 587 |
-
freq = get_1d_rotary_pos_embed(**common_kwargs)
|
| 588 |
-
|
| 589 |
-
freqs.append(freq)
|
| 590 |
-
|
| 591 |
-
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 592 |
-
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 593 |
-
return freqs_cos, freqs_sin
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
class MotifVideoImageProjection(nn.Module):
|
| 597 |
-
def __init__(self, in_features: int, hidden_size: int):
|
| 598 |
-
super().__init__()
|
| 599 |
-
self.norm_in = nn.LayerNorm(in_features)
|
| 600 |
-
self.linear_1 = nn.Linear(in_features, in_features)
|
| 601 |
-
self.act_fn = nn.GELU()
|
| 602 |
-
self.linear_2 = nn.Linear(in_features, hidden_size)
|
| 603 |
-
self.norm_out = nn.LayerNorm(hidden_size)
|
| 604 |
-
|
| 605 |
-
def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
|
| 606 |
-
hidden_states = self.norm_in(image_embeds)
|
| 607 |
-
hidden_states = self.linear_1(hidden_states)
|
| 608 |
-
hidden_states = self.act_fn(hidden_states)
|
| 609 |
-
hidden_states = self.linear_2(hidden_states)
|
| 610 |
-
hidden_states = self.norm_out(hidden_states)
|
| 611 |
-
return hidden_states
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
class MotifVideoSingleTransformerBlock(nn.Module):
|
| 615 |
-
def __init__(
|
| 616 |
-
self,
|
| 617 |
-
num_attention_heads: int,
|
| 618 |
-
attention_head_dim: int,
|
| 619 |
-
mlp_ratio: float = 4.0,
|
| 620 |
-
qk_norm: str = "rms_norm",
|
| 621 |
-
norm_type: str = "layer_norm",
|
| 622 |
-
enable_text_cross_attention: bool = False,
|
| 623 |
-
) -> None:
|
| 624 |
-
super().__init__()
|
| 625 |
-
|
| 626 |
-
hidden_size = num_attention_heads * attention_head_dim
|
| 627 |
-
mlp_dim = int(hidden_size * mlp_ratio)
|
| 628 |
-
|
| 629 |
-
self.attn = Attention(
|
| 630 |
-
query_dim=hidden_size,
|
| 631 |
-
cross_attention_dim=None,
|
| 632 |
-
dim_head=attention_head_dim,
|
| 633 |
-
heads=num_attention_heads,
|
| 634 |
-
out_dim=hidden_size,
|
| 635 |
-
bias=True,
|
| 636 |
-
processor=MotifVideoAttnProcessor2_0(),
|
| 637 |
-
qk_norm=qk_norm,
|
| 638 |
-
eps=1e-6,
|
| 639 |
-
pre_only=True,
|
| 640 |
-
)
|
| 641 |
-
|
| 642 |
-
self.enable_text_cross_attention = enable_text_cross_attention
|
| 643 |
-
if enable_text_cross_attention:
|
| 644 |
-
self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
|
| 645 |
-
self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
|
| 646 |
-
self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
|
| 647 |
-
nn.init.zeros_(self.cross_attn_out_proj.weight)
|
| 648 |
-
nn.init.zeros_(self.cross_attn_out_proj.bias)
|
| 649 |
-
|
| 650 |
-
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type=norm_type)
|
| 651 |
-
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
| 652 |
-
self.act_mlp = nn.GELU(approximate="tanh")
|
| 653 |
-
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
| 654 |
-
|
| 655 |
-
def forward(
|
| 656 |
-
self,
|
| 657 |
-
hidden_states: torch.Tensor,
|
| 658 |
-
encoder_hidden_states: torch.Tensor,
|
| 659 |
-
temb: torch.Tensor,
|
| 660 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 661 |
-
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 662 |
-
token_replace_emb: torch.Tensor | None = None,
|
| 663 |
-
first_frame_num_tokens: int | None = None,
|
| 664 |
-
image_embed_seq_len: int = 0,
|
| 665 |
-
encoder_attention_mask: torch.Tensor | None = None,
|
| 666 |
-
) -> torch.Tensor:
|
| 667 |
-
text_seq_length = encoder_hidden_states.shape[1]
|
| 668 |
-
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 669 |
-
|
| 670 |
-
residual = hidden_states
|
| 671 |
-
|
| 672 |
-
# 1. Input normalization
|
| 673 |
-
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 674 |
-
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 675 |
-
|
| 676 |
-
norm_hidden_states, norm_encoder_hidden_states = (
|
| 677 |
-
norm_hidden_states[:, :-text_seq_length, :],
|
| 678 |
-
norm_hidden_states[:, -text_seq_length:, :],
|
| 679 |
-
)
|
| 680 |
-
|
| 681 |
-
# 2. Attention
|
| 682 |
-
attn_output, context_attn_output = self.attn(
|
| 683 |
-
hidden_states=norm_hidden_states,
|
| 684 |
-
encoder_hidden_states=norm_encoder_hidden_states,
|
| 685 |
-
attention_mask=attention_mask,
|
| 686 |
-
image_rotary_emb=image_rotary_emb,
|
| 687 |
-
)
|
| 688 |
-
|
| 689 |
-
# Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
|
| 690 |
-
if self.enable_text_cross_attention:
|
| 691 |
-
txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
|
| 692 |
-
text_mask = None
|
| 693 |
-
if encoder_attention_mask is not None:
|
| 694 |
-
text_mask = encoder_attention_mask[:, image_embed_seq_len:]
|
| 695 |
-
text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
|
| 696 |
-
cross_q = self.cross_attn_query_proj(attn_output)
|
| 697 |
-
cross_output, _ = self.attn(
|
| 698 |
-
hidden_states=cross_q,
|
| 699 |
-
query_input=cross_q,
|
| 700 |
-
key_input=txt_kv,
|
| 701 |
-
value_input=txt_kv,
|
| 702 |
-
attention_mask=text_mask,
|
| 703 |
-
image_rotary_emb=image_rotary_emb,
|
| 704 |
-
)
|
| 705 |
-
attn_output = attn_output + self.cross_attn_out_proj(cross_output)
|
| 706 |
-
|
| 707 |
-
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
| 708 |
-
|
| 709 |
-
# 3. Modulation and residual connection
|
| 710 |
-
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 711 |
-
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
|
| 712 |
-
hidden_states = hidden_states + residual
|
| 713 |
-
|
| 714 |
-
hidden_states, encoder_hidden_states = (
|
| 715 |
-
hidden_states[:, :-text_seq_length, :],
|
| 716 |
-
hidden_states[:, -text_seq_length:, :],
|
| 717 |
-
)
|
| 718 |
-
return hidden_states, encoder_hidden_states
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
class MotifVideoTransformerBlock(nn.Module):
|
| 722 |
-
def __init__(
|
| 723 |
-
self,
|
| 724 |
-
num_attention_heads: int,
|
| 725 |
-
attention_head_dim: int,
|
| 726 |
-
mlp_ratio: float,
|
| 727 |
-
qk_norm: str = "rms_norm",
|
| 728 |
-
norm_type: str = "layer_norm",
|
| 729 |
-
enable_text_cross_attention: bool = False,
|
| 730 |
-
) -> None:
|
| 731 |
-
super().__init__()
|
| 732 |
-
|
| 733 |
-
hidden_size = num_attention_heads * attention_head_dim
|
| 734 |
-
|
| 735 |
-
self.norm1 = AdaLayerNormZero(hidden_size, norm_type=norm_type)
|
| 736 |
-
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type=norm_type)
|
| 737 |
-
|
| 738 |
-
self.attn = Attention(
|
| 739 |
-
query_dim=hidden_size,
|
| 740 |
-
cross_attention_dim=None,
|
| 741 |
-
added_kv_proj_dim=hidden_size,
|
| 742 |
-
dim_head=attention_head_dim,
|
| 743 |
-
heads=num_attention_heads,
|
| 744 |
-
out_dim=hidden_size,
|
| 745 |
-
context_pre_only=False,
|
| 746 |
-
bias=True,
|
| 747 |
-
processor=MotifVideoAttnProcessor2_0(),
|
| 748 |
-
qk_norm=qk_norm,
|
| 749 |
-
eps=1e-6,
|
| 750 |
-
)
|
| 751 |
-
|
| 752 |
-
self.enable_text_cross_attention = enable_text_cross_attention
|
| 753 |
-
if enable_text_cross_attention:
|
| 754 |
-
self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
|
| 755 |
-
self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
|
| 756 |
-
self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
|
| 757 |
-
nn.init.zeros_(self.cross_attn_out_proj.weight)
|
| 758 |
-
nn.init.zeros_(self.cross_attn_out_proj.bias)
|
| 759 |
-
|
| 760 |
-
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 761 |
-
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 762 |
-
|
| 763 |
-
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 764 |
-
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 765 |
-
|
| 766 |
-
def forward(
|
| 767 |
-
self,
|
| 768 |
-
hidden_states: torch.Tensor,
|
| 769 |
-
encoder_hidden_states: torch.Tensor,
|
| 770 |
-
temb: torch.Tensor,
|
| 771 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 772 |
-
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 773 |
-
token_replace_emb: torch.Tensor | None = None,
|
| 774 |
-
first_frame_num_tokens: int | None = None,
|
| 775 |
-
image_embed_seq_len: int = 0,
|
| 776 |
-
encoder_attention_mask: torch.Tensor | None = None,
|
| 777 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 778 |
-
# 1. Input normalization
|
| 779 |
-
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 780 |
-
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 781 |
-
encoder_hidden_states, emb=temb
|
| 782 |
-
)
|
| 783 |
-
|
| 784 |
-
# 2. Joint attention
|
| 785 |
-
attn_output, context_attn_output = self.attn(
|
| 786 |
-
hidden_states=norm_hidden_states,
|
| 787 |
-
encoder_hidden_states=norm_encoder_hidden_states,
|
| 788 |
-
attention_mask=attention_mask,
|
| 789 |
-
image_rotary_emb=image_rotary_emb,
|
| 790 |
-
)
|
| 791 |
-
|
| 792 |
-
# 3. Modulation and residual connection
|
| 793 |
-
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
| 794 |
-
|
| 795 |
-
# Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
|
| 796 |
-
if self.enable_text_cross_attention:
|
| 797 |
-
txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
|
| 798 |
-
text_mask = None
|
| 799 |
-
if encoder_attention_mask is not None:
|
| 800 |
-
text_mask = encoder_attention_mask[:, image_embed_seq_len:]
|
| 801 |
-
text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
|
| 802 |
-
cross_q = self.cross_attn_query_proj(attn_output)
|
| 803 |
-
cross_output, _ = self.attn(
|
| 804 |
-
hidden_states=cross_q,
|
| 805 |
-
query_input=cross_q,
|
| 806 |
-
key_input=txt_kv,
|
| 807 |
-
value_input=txt_kv,
|
| 808 |
-
attention_mask=text_mask,
|
| 809 |
-
image_rotary_emb=image_rotary_emb,
|
| 810 |
-
)
|
| 811 |
-
hidden_states = hidden_states + self.cross_attn_out_proj(cross_output)
|
| 812 |
-
|
| 813 |
-
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
| 814 |
-
|
| 815 |
-
norm_hidden_states = self.norm2(hidden_states)
|
| 816 |
-
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 817 |
-
|
| 818 |
-
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 819 |
-
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 820 |
-
|
| 821 |
-
# 4. Feed-forward
|
| 822 |
-
ff_output = self.ff(norm_hidden_states)
|
| 823 |
-
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 824 |
-
|
| 825 |
-
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
| 826 |
-
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 827 |
-
|
| 828 |
-
return hidden_states, encoder_hidden_states
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
TransformerBlockRegistry.register(
|
| 832 |
-
model_class=MotifVideoTransformerBlock,
|
| 833 |
-
metadata=TransformerBlockMetadata(
|
| 834 |
-
return_hidden_states_index=0,
|
| 835 |
-
return_encoder_hidden_states_index=1,
|
| 836 |
-
),
|
| 837 |
-
)
|
| 838 |
-
TransformerBlockRegistry.register(
|
| 839 |
-
model_class=MotifVideoSingleTransformerBlock,
|
| 840 |
-
metadata=TransformerBlockMetadata(
|
| 841 |
-
return_hidden_states_index=0,
|
| 842 |
-
return_encoder_hidden_states_index=1,
|
| 843 |
-
),
|
| 844 |
-
)
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
class MotifVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
| 848 |
-
r"""
|
| 849 |
-
A Transformer model for video-like data used in [MotifVideo](https://huggingface.co/motif/motifvideo).
|
| 850 |
-
|
| 851 |
-
Args:
|
| 852 |
-
in_channels (`int`, defaults to `16`):
|
| 853 |
-
The number of channels in the input.
|
| 854 |
-
out_channels (`int`, defaults to `16`):
|
| 855 |
-
The number of channels in the output.
|
| 856 |
-
num_attention_heads (`int`, defaults to `24`):
|
| 857 |
-
The number of heads to use for multi-head attention.
|
| 858 |
-
attention_head_dim (`int`, defaults to `128`):
|
| 859 |
-
The number of channels in each head.
|
| 860 |
-
num_layers (`int`, defaults to `20`):
|
| 861 |
-
The number of layers of dual-stream blocks to use.
|
| 862 |
-
num_single_layers (`int`, defaults to `40`):
|
| 863 |
-
The number of layers of single-stream blocks to use.
|
| 864 |
-
|
| 865 |
-
mlp_ratio (`float`, defaults to `4.0`):
|
| 866 |
-
The ratio of the hidden layer size to the input size in the feedforward network.
|
| 867 |
-
patch_size (`int`, defaults to `2`):
|
| 868 |
-
The size of the spatial patches to use in the patch embedding layer.
|
| 869 |
-
patch_size_t (`int`, defaults to `1`):
|
| 870 |
-
The size of the temporal patches to use in the patch embedding layer.
|
| 871 |
-
qk_norm (`str`, defaults to `rms_norm`):
|
| 872 |
-
The normalization to use for the query and key projections in the attention layers.
|
| 873 |
-
text_embed_dim (`int`, defaults to `4096`):
|
| 874 |
-
Input dimension of text embeddings from the text encoder.
|
| 875 |
-
rope_theta (`float`, defaults to `256.0`):
|
| 876 |
-
The value of theta to use in the RoPE layer.
|
| 877 |
-
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
| 878 |
-
The dimensions of the axes to use in the RoPE layer.
|
| 879 |
-
base_latent_size (`int`, *optional*):
|
| 880 |
-
The maximum spatial dimension (in latent units) seen during training.
|
| 881 |
-
For example, if trained on 1280x1280 with a VAE downscale of 16, this is 80.
|
| 882 |
-
"""
|
| 883 |
-
|
| 884 |
-
_supports_gradient_checkpointing = True
|
| 885 |
-
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
| 886 |
-
_no_split_modules = [
|
| 887 |
-
"MotifVideoTransformerBlock",
|
| 888 |
-
"MotifVideoSingleTransformerBlock",
|
| 889 |
-
"MotifVideoPatchEmbed",
|
| 890 |
-
]
|
| 891 |
-
|
| 892 |
-
@register_to_config
|
| 893 |
-
def __init__(
|
| 894 |
-
self,
|
| 895 |
-
in_channels: int = 33,
|
| 896 |
-
out_channels: int = 16,
|
| 897 |
-
num_attention_heads: int = 24,
|
| 898 |
-
attention_head_dim: int = 128,
|
| 899 |
-
num_layers: int = 20,
|
| 900 |
-
num_single_layers: int = 40,
|
| 901 |
-
num_decoder_layers: int = 0,
|
| 902 |
-
mlp_ratio: float = 4.0,
|
| 903 |
-
patch_size: int = 2,
|
| 904 |
-
patch_size_t: int = 1,
|
| 905 |
-
qk_norm: str = "rms_norm",
|
| 906 |
-
norm_type: str = "layer_norm",
|
| 907 |
-
text_embed_dim: int = 4096,
|
| 908 |
-
image_embed_dim: int | None = None,
|
| 909 |
-
pooled_projection_dim: int | None = None,
|
| 910 |
-
rope_theta: float = 256.0,
|
| 911 |
-
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
| 912 |
-
base_latent_size: int | None = None,
|
| 913 |
-
enable_text_cross_attention_dual: bool = False,
|
| 914 |
-
enable_text_cross_attention_single: bool = False,
|
| 915 |
-
) -> None:
|
| 916 |
-
super().__init__()
|
| 917 |
-
|
| 918 |
-
inner_dim = num_attention_heads * attention_head_dim
|
| 919 |
-
out_channels = out_channels or in_channels
|
| 920 |
-
|
| 921 |
-
# 1. Latent and condition embedders
|
| 922 |
-
self.x_embedder = MotifVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
|
| 923 |
-
self.context_embedder = PixArtAlphaTextProjection(in_features=text_embed_dim, hidden_size=inner_dim)
|
| 924 |
-
|
| 925 |
-
# First frame conditioning: Image conditioning embedders
|
| 926 |
-
self.image_embed_dim = image_embed_dim
|
| 927 |
-
if image_embed_dim is not None:
|
| 928 |
-
# Project image embeddings from vision encoder to transformer dim
|
| 929 |
-
self.image_embedder = MotifVideoImageProjection(in_features=image_embed_dim, hidden_size=inner_dim)
|
| 930 |
-
|
| 931 |
-
self.time_text_embed = MotifVideoConditionEmbedding(inner_dim, pooled_projection_dim)
|
| 932 |
-
|
| 933 |
-
# 2. RoPE
|
| 934 |
-
self.rope = MotifVideoRotaryPosEmbed(
|
| 935 |
-
patch_size, patch_size_t, rope_axes_dim, rope_theta, base_latent_size=base_latent_size
|
| 936 |
-
)
|
| 937 |
-
|
| 938 |
-
# Cross-attention config
|
| 939 |
-
self.enable_text_cross_attention_dual = enable_text_cross_attention_dual
|
| 940 |
-
self.enable_text_cross_attention_single = enable_text_cross_attention_single
|
| 941 |
-
|
| 942 |
-
# 3. Dual stream transformer blocks
|
| 943 |
-
self.transformer_blocks = nn.ModuleList(
|
| 944 |
-
[
|
| 945 |
-
MotifVideoTransformerBlock(
|
| 946 |
-
num_attention_heads,
|
| 947 |
-
attention_head_dim,
|
| 948 |
-
mlp_ratio=mlp_ratio,
|
| 949 |
-
qk_norm=qk_norm,
|
| 950 |
-
norm_type=norm_type,
|
| 951 |
-
enable_text_cross_attention=enable_text_cross_attention_dual,
|
| 952 |
-
)
|
| 953 |
-
for _ in range(num_layers)
|
| 954 |
-
]
|
| 955 |
-
)
|
| 956 |
-
|
| 957 |
-
# 4. Single stream transformer blocks
|
| 958 |
-
# Encoder blocks get cross-attention; decoder blocks do not (no text stream in decoder)
|
| 959 |
-
num_encoder_single = num_single_layers - num_decoder_layers
|
| 960 |
-
self.single_transformer_blocks = nn.ModuleList(
|
| 961 |
-
[
|
| 962 |
-
MotifVideoSingleTransformerBlock(
|
| 963 |
-
num_attention_heads,
|
| 964 |
-
attention_head_dim,
|
| 965 |
-
mlp_ratio=mlp_ratio,
|
| 966 |
-
qk_norm=qk_norm,
|
| 967 |
-
norm_type=norm_type,
|
| 968 |
-
enable_text_cross_attention=enable_text_cross_attention_single
|
| 969 |
-
if i < num_encoder_single
|
| 970 |
-
else False,
|
| 971 |
-
)
|
| 972 |
-
for i in range(num_single_layers)
|
| 973 |
-
]
|
| 974 |
-
)
|
| 975 |
-
|
| 976 |
-
# 5. Output projection
|
| 977 |
-
self.norm_out = AdaLayerNormContinuous(
|
| 978 |
-
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type=norm_type
|
| 979 |
-
)
|
| 980 |
-
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
|
| 981 |
-
|
| 982 |
-
# Verify cross-attention config matches actual block state.
|
| 983 |
-
# Catches silent misconfiguration (e.g. checkpoint config with renamed keys).
|
| 984 |
-
for i, block in enumerate(self.transformer_blocks):
|
| 985 |
-
if block.enable_text_cross_attention != enable_text_cross_attention_dual:
|
| 986 |
-
raise ValueError(
|
| 987 |
-
f"transformer_blocks[{i}].enable_text_cross_attention="
|
| 988 |
-
f"{block.enable_text_cross_attention}, expected {enable_text_cross_attention_dual}. "
|
| 989 |
-
f"Check checkpoint config.json key names match __init__ parameters."
|
| 990 |
-
)
|
| 991 |
-
num_encoder_single = num_single_layers - num_decoder_layers
|
| 992 |
-
for i, block in enumerate(self.single_transformer_blocks):
|
| 993 |
-
expected = enable_text_cross_attention_single if i < num_encoder_single else False
|
| 994 |
-
if block.enable_text_cross_attention != expected:
|
| 995 |
-
raise ValueError(
|
| 996 |
-
f"single_transformer_blocks[{i}].enable_text_cross_attention="
|
| 997 |
-
f"{block.enable_text_cross_attention}, expected {expected}. "
|
| 998 |
-
f"Check checkpoint config.json key names match __init__ parameters."
|
| 999 |
-
)
|
| 1000 |
-
|
| 1001 |
-
self.gradient_checkpointing = False
|
| 1002 |
-
self.num_decoder_layers = num_decoder_layers
|
| 1003 |
-
|
| 1004 |
-
@property
|
| 1005 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 1006 |
-
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 1007 |
-
r"""
|
| 1008 |
-
Returns:
|
| 1009 |
-
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 1010 |
-
indexed by its weight name.
|
| 1011 |
-
"""
|
| 1012 |
-
# set recursively
|
| 1013 |
-
processors = {}
|
| 1014 |
-
|
| 1015 |
-
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 1016 |
-
if hasattr(module, "get_processor"):
|
| 1017 |
-
processors[f"{name}.processor"] = module.get_processor()
|
| 1018 |
-
|
| 1019 |
-
for sub_name, child in module.named_children():
|
| 1020 |
-
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 1021 |
-
|
| 1022 |
-
return processors
|
| 1023 |
-
|
| 1024 |
-
for name, module in self.named_children():
|
| 1025 |
-
fn_recursive_add_processors(name, module, processors)
|
| 1026 |
-
|
| 1027 |
-
return processors
|
| 1028 |
-
|
| 1029 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 1030 |
-
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 1031 |
-
r"""
|
| 1032 |
-
Sets the attention processor to use to compute attention.
|
| 1033 |
-
|
| 1034 |
-
Parameters:
|
| 1035 |
-
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 1036 |
-
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 1037 |
-
for **all** `Attention` layers.
|
| 1038 |
-
|
| 1039 |
-
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 1040 |
-
processor. This is strongly recommended when setting trainable attention processors.
|
| 1041 |
-
|
| 1042 |
-
"""
|
| 1043 |
-
count = len(self.attn_processors.keys())
|
| 1044 |
-
|
| 1045 |
-
if isinstance(processor, dict) and len(processor) != count:
|
| 1046 |
-
raise ValueError(
|
| 1047 |
-
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 1048 |
-
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 1049 |
-
)
|
| 1050 |
-
|
| 1051 |
-
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 1052 |
-
if hasattr(module, "set_processor"):
|
| 1053 |
-
if not isinstance(processor, dict):
|
| 1054 |
-
module.set_processor(processor)
|
| 1055 |
-
else:
|
| 1056 |
-
module.set_processor(processor.pop(f"{name}.processor"))
|
| 1057 |
-
|
| 1058 |
-
for sub_name, child in module.named_children():
|
| 1059 |
-
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 1060 |
-
|
| 1061 |
-
for name, module in self.named_children():
|
| 1062 |
-
fn_recursive_attn_processor(name, module, processor)
|
| 1063 |
-
|
| 1064 |
-
def _maybe_gradient_checkpoint_block(self, block, *args):
|
| 1065 |
-
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1066 |
-
return self._gradient_checkpointing_func(block, *args)
|
| 1067 |
-
return block(*args)
|
| 1068 |
-
|
| 1069 |
-
def _get_unwrapped_blocks(self, blocks):
|
| 1070 |
-
if hasattr(blocks, "_checkpoint_wrapped_module"):
|
| 1071 |
-
return blocks._checkpoint_wrapped_module
|
| 1072 |
-
elif hasattr(blocks, "module"):
|
| 1073 |
-
return blocks.module
|
| 1074 |
-
return blocks
|
| 1075 |
-
|
| 1076 |
-
def _create_attention_mask(
|
| 1077 |
-
self,
|
| 1078 |
-
hidden_states: torch.Tensor,
|
| 1079 |
-
encoder_attention_mask: torch.Tensor,
|
| 1080 |
-
) -> torch.Tensor:
|
| 1081 |
-
"""
|
| 1082 |
-
Create attention mask of shape [B, 1, 1, N] where N = L + E,
|
| 1083 |
-
based on latent tokens (always valid) and the encoder mask.
|
| 1084 |
-
|
| 1085 |
-
Args:
|
| 1086 |
-
hidden_states: [B, L, D]
|
| 1087 |
-
encoder_attention_mask: [B, E] (required)
|
| 1088 |
-
|
| 1089 |
-
Returns:
|
| 1090 |
-
attention_mask: [B, 1, 1, N]
|
| 1091 |
-
"""
|
| 1092 |
-
attention_mask = F.pad(
|
| 1093 |
-
encoder_attention_mask.to(torch.bool),
|
| 1094 |
-
(hidden_states.shape[1], 0),
|
| 1095 |
-
value=True,
|
| 1096 |
-
)
|
| 1097 |
-
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, L+E]
|
| 1098 |
-
return attention_mask
|
| 1099 |
-
|
| 1100 |
-
def forward(
|
| 1101 |
-
self,
|
| 1102 |
-
hidden_states: torch.Tensor,
|
| 1103 |
-
timestep: torch.LongTensor,
|
| 1104 |
-
encoder_hidden_states: torch.Tensor,
|
| 1105 |
-
encoder_attention_mask: torch.Tensor | None = None,
|
| 1106 |
-
pooled_projections: torch.Tensor | None = None,
|
| 1107 |
-
image_embeds: torch.Tensor | None = None,
|
| 1108 |
-
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1109 |
-
return_dict: bool = True,
|
| 1110 |
-
tread_mixin: Optional[Any] = None,
|
| 1111 |
-
tread_disabled: bool = False,
|
| 1112 |
-
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
| 1113 |
-
"""
|
| 1114 |
-
Forward pass of the MotifVideoTransformer3DModel.
|
| 1115 |
-
|
| 1116 |
-
Args:
|
| 1117 |
-
hidden_states: Input latent tensor [B, C, F, H, W].
|
| 1118 |
-
timestep: Diffusion timesteps [B].
|
| 1119 |
-
encoder_hidden_states: Text conditioning [B, E, D].
|
| 1120 |
-
encoder_attention_mask: Mask for text conditioning [B, E].
|
| 1121 |
-
pooled_projections: Pooled text embeddings [B, D].
|
| 1122 |
-
image_embeds: Optional image embeddings from vision encoder [B, N, D].
|
| 1123 |
-
attention_kwargs: Additional arguments for attention processors.
|
| 1124 |
-
return_dict: Whether to return a Transformer2DModelOutput.
|
| 1125 |
-
tread_mixin: Optional TreadMixin instance for token reduction.
|
| 1126 |
-
tread_disabled: When True, force tread_mixin to None (dense pass).
|
| 1127 |
-
torch.compile specializes on this bool, producing separate graphs
|
| 1128 |
-
for dense vs routed without attribute toggling.
|
| 1129 |
-
|
| 1130 |
-
Returns:
|
| 1131 |
-
Transformer2DModelOutput or tuple containing the predicted samples.
|
| 1132 |
-
"""
|
| 1133 |
-
if tread_disabled:
|
| 1134 |
-
tread_mixin = None
|
| 1135 |
-
elif tread_mixin is None:
|
| 1136 |
-
tread_mixin = getattr(self, "_inference_tread_mixin", None)
|
| 1137 |
-
|
| 1138 |
-
if attention_kwargs is not None:
|
| 1139 |
-
attention_kwargs = attention_kwargs.copy()
|
| 1140 |
-
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 1141 |
-
else:
|
| 1142 |
-
lora_scale = 1.0
|
| 1143 |
-
|
| 1144 |
-
if USE_PEFT_BACKEND:
|
| 1145 |
-
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 1146 |
-
scale_lora_layers(self, lora_scale)
|
| 1147 |
-
else:
|
| 1148 |
-
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 1149 |
-
logger.warning(
|
| 1150 |
-
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
| 1151 |
-
)
|
| 1152 |
-
|
| 1153 |
-
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 1154 |
-
p, p_t = self.config.patch_size, self.config.patch_size_t
|
| 1155 |
-
post_patch_num_frames = num_frames // p_t
|
| 1156 |
-
post_patch_height = height // p
|
| 1157 |
-
post_patch_width = width // p
|
| 1158 |
-
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
|
| 1159 |
-
# 1. RoPE
|
| 1160 |
-
image_rotary_emb = self.rope(hidden_states, timestep=timestep)
|
| 1161 |
-
# 2. Conditional embeddings
|
| 1162 |
-
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)
|
| 1163 |
-
hidden_states = self.x_embedder(hidden_states)
|
| 1164 |
-
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 1165 |
-
|
| 1166 |
-
# First frame conditioning: Image embeddings from vision encoder
|
| 1167 |
-
if image_embeds is not None:
|
| 1168 |
-
# image_embeds: [B, N, D_img] -> [B, N, D]
|
| 1169 |
-
image_embeds = self.image_embedder(image_embeds)
|
| 1170 |
-
encoder_hidden_states = torch.cat([image_embeds, encoder_hidden_states], dim=1)
|
| 1171 |
-
# Extend attention mask for image tokens
|
| 1172 |
-
if encoder_attention_mask is not None:
|
| 1173 |
-
image_mask = torch.ones(
|
| 1174 |
-
image_embeds.shape[0],
|
| 1175 |
-
image_embeds.shape[1],
|
| 1176 |
-
device=encoder_attention_mask.device,
|
| 1177 |
-
dtype=encoder_attention_mask.dtype,
|
| 1178 |
-
)
|
| 1179 |
-
encoder_attention_mask = torch.cat([image_mask, encoder_attention_mask], dim=1)
|
| 1180 |
-
|
| 1181 |
-
# image_embed_seq_len: used by cross-attention blocks to slice text from encoder_hidden_states
|
| 1182 |
-
image_embed_seq_len = image_embeds.shape[1] if image_embeds is not None else 0
|
| 1183 |
-
|
| 1184 |
-
decoder_hidden_states = hidden_states.clone()
|
| 1185 |
-
|
| 1186 |
-
if encoder_attention_mask is not None:
|
| 1187 |
-
attention_mask = self._create_attention_mask(
|
| 1188 |
-
hidden_states=hidden_states,
|
| 1189 |
-
encoder_attention_mask=encoder_attention_mask,
|
| 1190 |
-
)
|
| 1191 |
-
else:
|
| 1192 |
-
attention_mask = None
|
| 1193 |
-
|
| 1194 |
-
# TREAD state initialization: manage token reduction manually to support activation checkpointing
|
| 1195 |
-
tread_active = False
|
| 1196 |
-
current_route = None
|
| 1197 |
-
ids_keep = None
|
| 1198 |
-
x_full = None
|
| 1199 |
-
orig_mask = attention_mask
|
| 1200 |
-
orig_rope = image_rotary_emb
|
| 1201 |
-
latent_len = hidden_states.shape[1]
|
| 1202 |
-
|
| 1203 |
-
# 4. Dual stream transformer blocks (Encoder)
|
| 1204 |
-
for i, block in enumerate(self.transformer_blocks):
|
| 1205 |
-
# Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
|
| 1206 |
-
if is_tread_start(tread_mixin, tread_active, i):
|
| 1207 |
-
tread_active = True
|
| 1208 |
-
current_route = tread_mixin._tread_route
|
| 1209 |
-
# Reduce sequence length at the start of a TREAD route
|
| 1210 |
-
ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
|
| 1211 |
-
x_full = hidden_states.contiguous()
|
| 1212 |
-
hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
|
| 1213 |
-
attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
|
| 1214 |
-
image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
|
| 1215 |
-
|
| 1216 |
-
hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
|
| 1217 |
-
block,
|
| 1218 |
-
hidden_states,
|
| 1219 |
-
encoder_hidden_states,
|
| 1220 |
-
temb,
|
| 1221 |
-
attention_mask,
|
| 1222 |
-
image_rotary_emb,
|
| 1223 |
-
token_replace_emb,
|
| 1224 |
-
first_frame_num_tokens,
|
| 1225 |
-
image_embed_seq_len,
|
| 1226 |
-
encoder_attention_mask,
|
| 1227 |
-
)
|
| 1228 |
-
|
| 1229 |
-
if is_tread_end(tread_mixin, tread_active, i):
|
| 1230 |
-
# Restore full sequence length at the end of a TREAD route
|
| 1231 |
-
hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
|
| 1232 |
-
tread_active = False
|
| 1233 |
-
current_route = None
|
| 1234 |
-
ids_keep = None
|
| 1235 |
-
x_full = None
|
| 1236 |
-
attention_mask = orig_mask
|
| 1237 |
-
image_rotary_emb = orig_rope
|
| 1238 |
-
|
| 1239 |
-
# We need to unwrap the blocks because CheckpointWrapper does not support len(),
|
| 1240 |
-
# which is required for slicing the blocks into encoder and decoder parts.
|
| 1241 |
-
single_transformer_blocks = self.single_transformer_blocks
|
| 1242 |
-
|
| 1243 |
-
# 5. Single stream transformer blocks (Encoder)
|
| 1244 |
-
num_dual = len(self.transformer_blocks)
|
| 1245 |
-
for i, block in enumerate(
|
| 1246 |
-
single_transformer_blocks[: len(single_transformer_blocks) - self.num_decoder_layers]
|
| 1247 |
-
):
|
| 1248 |
-
# Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
|
| 1249 |
-
abs_i = num_dual + i
|
| 1250 |
-
if is_tread_start(tread_mixin, tread_active, abs_i):
|
| 1251 |
-
tread_active = True
|
| 1252 |
-
current_route = tread_mixin._tread_route
|
| 1253 |
-
# Reduce sequence length at the start of a TREAD route
|
| 1254 |
-
ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
|
| 1255 |
-
x_full = hidden_states.contiguous()
|
| 1256 |
-
hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
|
| 1257 |
-
attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
|
| 1258 |
-
image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
|
| 1259 |
-
|
| 1260 |
-
hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
|
| 1261 |
-
block,
|
| 1262 |
-
hidden_states,
|
| 1263 |
-
encoder_hidden_states,
|
| 1264 |
-
temb,
|
| 1265 |
-
attention_mask,
|
| 1266 |
-
image_rotary_emb,
|
| 1267 |
-
token_replace_emb,
|
| 1268 |
-
first_frame_num_tokens,
|
| 1269 |
-
image_embed_seq_len,
|
| 1270 |
-
encoder_attention_mask,
|
| 1271 |
-
)
|
| 1272 |
-
|
| 1273 |
-
if is_tread_end(tread_mixin, tread_active, abs_i):
|
| 1274 |
-
# Restore full sequence length at the end of a TREAD route
|
| 1275 |
-
hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
|
| 1276 |
-
tread_active = False
|
| 1277 |
-
current_route = None
|
| 1278 |
-
ids_keep = None
|
| 1279 |
-
x_full = None
|
| 1280 |
-
attention_mask = orig_mask
|
| 1281 |
-
image_rotary_emb = orig_rope
|
| 1282 |
-
|
| 1283 |
-
# 6. Single stream transformer blocks (Decoder)
|
| 1284 |
-
if self.num_decoder_layers > 0:
|
| 1285 |
-
encoder_hidden_states = hidden_states
|
| 1286 |
-
attention_mask = None
|
| 1287 |
-
|
| 1288 |
-
num_single = len(single_transformer_blocks)
|
| 1289 |
-
|
| 1290 |
-
for i, block in enumerate(single_transformer_blocks[-self.num_decoder_layers :]):
|
| 1291 |
-
abs_i = num_dual + (num_single - self.num_decoder_layers) + i
|
| 1292 |
-
if is_tread_start(tread_mixin, tread_active, abs_i):
|
| 1293 |
-
tread_active = True
|
| 1294 |
-
current_route = tread_mixin._tread_route
|
| 1295 |
-
# Reduce sequence length at the start of a TREAD route
|
| 1296 |
-
ids_keep = tread_mixin.keep_indices(decoder_hidden_states, current_route["sel"]).to(
|
| 1297 |
-
decoder_hidden_states.device
|
| 1298 |
-
)
|
| 1299 |
-
x_full = encoder_hidden_states.contiguous()
|
| 1300 |
-
x_t_full = decoder_hidden_states.contiguous()
|
| 1301 |
-
decoder_hidden_states = tread_mixin.gather_tokens(decoder_hidden_states, ids_keep)
|
| 1302 |
-
encoder_hidden_states = tread_mixin.gather_tokens(encoder_hidden_states, ids_keep)
|
| 1303 |
-
attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
|
| 1304 |
-
image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
|
| 1305 |
-
|
| 1306 |
-
decoder_hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
|
| 1307 |
-
block,
|
| 1308 |
-
decoder_hidden_states,
|
| 1309 |
-
encoder_hidden_states,
|
| 1310 |
-
temb,
|
| 1311 |
-
attention_mask,
|
| 1312 |
-
image_rotary_emb,
|
| 1313 |
-
token_replace_emb,
|
| 1314 |
-
first_frame_num_tokens,
|
| 1315 |
-
)
|
| 1316 |
-
|
| 1317 |
-
if is_tread_end(tread_mixin, tread_active, abs_i):
|
| 1318 |
-
# Restore full sequence length at the end of a TREAD route
|
| 1319 |
-
decoder_hidden_states = tread_mixin.scatter_tokens(decoder_hidden_states, ids_keep, x_t_full)
|
| 1320 |
-
encoder_hidden_states = tread_mixin.scatter_tokens(encoder_hidden_states, ids_keep, x_full)
|
| 1321 |
-
tread_active = False
|
| 1322 |
-
current_route = None
|
| 1323 |
-
ids_keep = None
|
| 1324 |
-
x_full = None
|
| 1325 |
-
x_t_full = None
|
| 1326 |
-
attention_mask = orig_mask
|
| 1327 |
-
image_rotary_emb = orig_rope
|
| 1328 |
-
|
| 1329 |
-
hidden_states = decoder_hidden_states
|
| 1330 |
-
|
| 1331 |
-
# 7. Output projection
|
| 1332 |
-
hidden_states = self.norm_out(hidden_states, temb)
|
| 1333 |
-
hidden_states = self.proj_out(hidden_states)
|
| 1334 |
-
|
| 1335 |
-
hidden_states = hidden_states.reshape(
|
| 1336 |
-
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
|
| 1337 |
-
)
|
| 1338 |
-
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
| 1339 |
-
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
| 1340 |
-
|
| 1341 |
-
if USE_PEFT_BACKEND:
|
| 1342 |
-
# remove `lora_scale` from each PEFT layer
|
| 1343 |
-
unscale_lora_layers(self, lora_scale)
|
| 1344 |
-
|
| 1345 |
-
if not return_dict:
|
| 1346 |
-
return (hidden_states,)
|
| 1347 |
-
|
| 1348 |
-
return Transformer2DModelOutput(
|
| 1349 |
-
sample=hidden_states,
|
| 1350 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|