Text-to-Video
Diffusers
Safetensors
English
MotifVideoPipeline
image-to-video
video-generation
diffusion-transformer
Instructions to use Motif-Technologies/Motif-Video-2B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Motif-Technologies/Motif-Video-2B with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("Motif-Technologies/Motif-Video-2B", dtype=torch.bfloat16, device_map="cuda") prompt = "A vibrant blue jay perches gracefully on a slender branch, its feathers shimmering in the soft morning light. The bird's keen eyes scan the surroundings, capturing the essence of the tranquil forest. It flutters its wings briefly, showcasing the intricate patterns of blue, white, and black on its plumage. The background reveals a lush canopy of green leaves, with rays of sunlight filtering through, creating a dappled effect on the forest floor. The blue jay then tilts its head, emitting a melodious call that echoes through the serene woodland, adding a touch of magic to the peaceful scene." image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
[WIP] diffusers integration
#21
by kencwt - opened
- .gitattributes +10 -0
- .gitignore +0 -0
- README.md +126 -74
- _fm_solvers_unipc.py +0 -759
- assets/astronaut.mp4 +3 -0
- assets/bird.mp4 +3 -0
- assets/fisherman.mp4 +3 -0
- assets/i2v_sample.jpg +3 -0
- assets/sage_compare_BF16.webp +3 -0
- assets/sage_compare_Q4_K_M.webp +3 -0
- assets/sage_compare_Q5_K_M.webp +3 -0
- assets/sage_compare_Q8_0.webp +3 -0
- assets/underwater.mp4 +3 -0
- assets/woman.mp4 +3 -0
- docs/gguf-sageattention.md +132 -0
- docs/memory-efficient-inference.md +88 -0
- inference.py +0 -119
- model_index.json +1 -1
- pipeline_motif_video.py +0 -1388
- transformer/config.json +0 -2
- transformer/transformer_motif_video.py +0 -1350
.gitattributes
CHANGED
|
@@ -40,3 +40,13 @@ assets/showcase_t2v.png filter=lfs diff=lfs merge=lfs -text
|
|
| 40 |
tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 41 |
assets/i2v_sample.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
motif-video-technical-report.pdf filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 41 |
assets/i2v_sample.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
motif-video-technical-report.pdf filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
assets/astronaut.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
assets/bird.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
assets/fisherman.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
assets/underwater.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
assets/vows.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
assets/woman.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
assets/sage_compare_BF16.webp filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
assets/sage_compare_Q4_K_M.webp filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
assets/sage_compare_Q5_K_M.webp filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
assets/sage_compare_Q8_0.webp filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
README.md
CHANGED
|
@@ -9,6 +9,22 @@ tags:
|
|
| 9 |
- diffusion-transformer
|
| 10 |
pipeline_tag: text-to-video
|
| 11 |
library_name: diffusers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
<p align="center">
|
|
@@ -31,8 +47,25 @@ library_name: diffusers
|
|
| 31 |
|
| 32 |
---
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
## 🔥 News
|
| 35 |
|
|
|
|
|
|
|
| 36 |
- **[2026-04-14]** We release **Motif-Video 2B**, our 2B-parameter text-to-video and image-to-video diffusion transformer, together with the full [technical report](https://arxiv.org/abs/2604.16503).
|
| 37 |
|
| 38 |
---
|
|
@@ -108,41 +141,74 @@ For the full derivation of why Shared Cross-Attention shares K/V but not Q, and
|
|
| 108 |
### Requirements
|
| 109 |
|
| 110 |
- Python 3.10+
|
| 111 |
-
- CUDA-capable GPU with **30GB+ VRAM** (e.g., A100, H100) — for 24GB GPUs see [Memory-efficient Inference](
|
| 112 |
|
| 113 |
```bash
|
| 114 |
-
pip install "
|
|
|
|
| 115 |
```
|
| 116 |
|
| 117 |
### Text-to-Video (T2V)
|
| 118 |
|
| 119 |
```python
|
| 120 |
import torch
|
| 121 |
-
from diffusers import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
from diffusers.utils import export_to_video
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
guider = AdaptiveProjectedGuidance(
|
| 125 |
guidance_scale=8.0,
|
| 126 |
adaptive_projected_guidance_rescale=12.0,
|
| 127 |
adaptive_projected_guidance_momentum=0.1,
|
| 128 |
use_original_formulation=True,
|
|
|
|
| 129 |
)
|
| 130 |
|
| 131 |
-
pipe =
|
| 132 |
"Motif-Technologies/Motif-Video-2B",
|
| 133 |
-
|
| 134 |
-
trust_remote_code=True,
|
| 135 |
torch_dtype=torch.bfloat16,
|
| 136 |
guider=guider,
|
| 137 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
pipe = pipe.to("cuda")
|
| 139 |
|
| 140 |
output = pipe(
|
| 141 |
-
prompt="A
|
|
|
|
| 142 |
height=736,
|
| 143 |
width=1280,
|
| 144 |
num_frames=121,
|
| 145 |
num_inference_steps=50,
|
|
|
|
|
|
|
| 146 |
)
|
| 147 |
|
| 148 |
export_to_video(output.frames[0], "output.mp4", fps=24)
|
|
@@ -152,7 +218,11 @@ export_to_video(output.frames[0], "output.mp4", fps=24)
|
|
| 152 |
|
| 153 |
```python
|
| 154 |
import torch
|
| 155 |
-
from diffusers import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
from diffusers.utils import export_to_video, load_image
|
| 157 |
|
| 158 |
guider = AdaptiveProjectedGuidance(
|
|
@@ -160,26 +230,38 @@ guider = AdaptiveProjectedGuidance(
|
|
| 160 |
adaptive_projected_guidance_rescale=12.0,
|
| 161 |
adaptive_projected_guidance_momentum=0.1,
|
| 162 |
use_original_formulation=True,
|
|
|
|
| 163 |
)
|
| 164 |
|
| 165 |
-
pipe =
|
| 166 |
"Motif-Technologies/Motif-Video-2B",
|
| 167 |
-
|
| 168 |
-
trust_remote_code=True,
|
| 169 |
torch_dtype=torch.bfloat16,
|
| 170 |
guider=guider,
|
| 171 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
pipe = pipe.to("cuda")
|
| 173 |
|
| 174 |
image = load_image("https://huggingface.co/Motif-Technologies/Motif-Video-2B/resolve/main/assets/i2v_sample.jpg")
|
| 175 |
|
| 176 |
output = pipe(
|
| 177 |
-
prompt="Three friends stride through a sun-bleached meadow as a warm breeze ripples the tall dry grass around their legs.
|
|
|
|
| 178 |
image=image,
|
| 179 |
height=736,
|
| 180 |
width=1280,
|
| 181 |
num_frames=121,
|
| 182 |
num_inference_steps=50,
|
|
|
|
|
|
|
| 183 |
)
|
| 184 |
|
| 185 |
export_to_video(output.frames[0], "output.mp4", fps=24)
|
|
@@ -188,96 +270,66 @@ export_to_video(output.frames[0], "output.mp4", fps=24)
|
|
| 188 |
### CLI Inference
|
| 189 |
|
| 190 |
```bash
|
| 191 |
-
# Text-to-Video
|
| 192 |
python inference.py \
|
| 193 |
-
--prompt "A
|
| 194 |
--output t2v_output.mp4
|
| 195 |
|
| 196 |
-
#
|
| 197 |
python inference.py \
|
| 198 |
-
--
|
| 199 |
-
--
|
| 200 |
-
--output
|
| 201 |
```
|
| 202 |
|
| 203 |
-
See `inference.py` for all available options
|
| 204 |
|
| 205 |
### Recommended Settings
|
| 206 |
|
| 207 |
| Parameter | Default | Notes |
|
| 208 |
|---|---|---|
|
| 209 |
-
| Resolution |
|
| 210 |
| Frames | 121 | ~5 seconds at 24fps |
|
| 211 |
-
|
|
|
|
|
| 212 |
| Inference steps | 50 | |
|
|
|
|
|
|
|
| 213 |
| dtype | bfloat16 | Recommended for H100/A100 |
|
| 214 |
|
| 215 |
### 🔋 Memory-efficient Inference
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
For GPUs with 24 GB or less (e.g. RTX 4090, RTX 3090), use `enable_model_cpu_offload()` with the `expandable_segments` allocator setting:
|
| 220 |
-
|
| 221 |
-
```bash
|
| 222 |
-
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 223 |
-
```
|
| 224 |
-
|
| 225 |
-
```python
|
| 226 |
-
pipe = DiffusionPipeline.from_pretrained(
|
| 227 |
-
"Motif-Technologies/Motif-Video-2B",
|
| 228 |
-
custom_pipeline="pipeline_motif_video",
|
| 229 |
-
trust_remote_code=True,
|
| 230 |
-
torch_dtype=torch.bfloat16,
|
| 231 |
-
guider=guider, # see T2V example above
|
| 232 |
-
)
|
| 233 |
-
pipe.enable_model_cpu_offload() # replaces pipe.to("cuda")
|
| 234 |
-
|
| 235 |
-
output = pipe(prompt="...", height=736, width=1280, num_frames=121, num_inference_steps=50)
|
| 236 |
-
export_to_video(output.frames[0], "output.mp4", fps=24)
|
| 237 |
-
```
|
| 238 |
-
|
| 239 |
-
This moves each component (text encoder → transformer → VAE) to GPU only when needed. The `expandable_segments` setting allows the CUDA memory allocator to efficiently reuse memory released by earlier components, avoiding fragmentation-related OOM errors.
|
| 240 |
-
|
| 241 |
-
| Mode | Peak VRAM | Speed | Recommended GPU |
|
| 242 |
-
|------|-----------|-------|-----------------|
|
| 243 |
-
| `pipe.to("cuda")` | ~30 GB | Fastest | A100, H100, H200 |
|
| 244 |
-
| `enable_model_cpu_offload()` | ~19 GB | Similar | RTX 4090, RTX 3090 |
|
| 245 |
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
-
|
| 249 |
|
| 250 |
-
|
| 251 |
-
pip install torchao
|
| 252 |
-
```
|
| 253 |
|
| 254 |
-
|
| 255 |
-
from torchao.quantization import quantize_, Float8WeightOnlyConfig
|
| 256 |
|
| 257 |
-
|
| 258 |
-
"Motif-Technologies/Motif-Video-2B",
|
| 259 |
-
custom_pipeline="pipeline_motif_video",
|
| 260 |
-
trust_remote_code=True,
|
| 261 |
-
torch_dtype=torch.bfloat16,
|
| 262 |
-
guider=guider, # see T2V example above
|
| 263 |
-
)
|
| 264 |
-
quantize_(pipe.transformer, Float8WeightOnlyConfig())
|
| 265 |
-
pipe.enable_model_cpu_offload()
|
| 266 |
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
| 270 |
|
| 271 |
-
|
| 272 |
|
| 273 |
-
|
| 274 |
-
|------|-----------|-------|
|
| 275 |
-
| `enable_model_cpu_offload()` | ~19 GB | BF16 baseline |
|
| 276 |
-
| `+ Float8WeightOnlyConfig` | ~15 GB | FP8 weights, BF16 compute |
|
| 277 |
|
| 278 |
### 🖥️ ComfyUI
|
| 279 |
|
| 280 |
-
Official ComfyUI custom nodes
|
|
|
|
|
|
|
| 281 |
|
| 282 |
---
|
| 283 |
|
|
|
|
| 9 |
- diffusion-transformer
|
| 10 |
pipeline_tag: text-to-video
|
| 11 |
library_name: diffusers
|
| 12 |
+
widget:
|
| 13 |
+
- text: "A vibrant blue jay perches gracefully on a slender branch, its feathers shimmering in the soft morning light. The bird's keen eyes scan the surroundings, capturing the essence of the tranquil forest. It flutters its wings briefly, showcasing the intricate patterns of blue, white, and black on its plumage. The background reveals a lush canopy of green leaves, with rays of sunlight filtering through, creating a dappled effect on the forest floor. The blue jay then tilts its head, emitting a melodious call that echoes through the serene woodland, adding a touch of magic to the peaceful scene."
|
| 14 |
+
output:
|
| 15 |
+
url: assets/bird.mp4
|
| 16 |
+
- text: "Underwater footage of a vibrant coral reef ecosystem with tropical fish swimming through coral formations. Natural sunlight filtering down through clear water creates dancing light patterns on the reef. Smooth underwater camera movement, natural color correction preserving authentic ocean blues and coral colors, documentary marine biology style, peaceful and educational mood."
|
| 17 |
+
output:
|
| 18 |
+
url: assets/underwater.mp4
|
| 19 |
+
- text: "An old fisherman mends his nets on a stone harbor wall, weathered hands moving with practiced speed through the green mesh. Shot on a 50mm lens with a slow dolly-in from his side, the afternoon sun throws warm light across his salt-stained coat and the worn granite beneath him. Behind him, a single wooden boat bobs gently in a turquoise bay. Gulls drift through the distant sky in soft focus. The camera settles on his hands, then racks focus to his weathered, squinting eyes."
|
| 20 |
+
output:
|
| 21 |
+
url: assets/fisherman.mp4
|
| 22 |
+
- text: "A lone astronaut drifts just outside a derelict space station, tethered by a single silver line as Earth's terminator glows blue-white behind her. Shot with a slow wide-to-medium push, the camera floats alongside her in weightless silence, the curvature of the planet filling the lower third of the frame. Sunlight rakes across the hull's scarred panels, casting long hard shadows that stretch and shift as she rotates. Her visor reflects the aurora below, ribbons of green pulling across the glass. She reaches out with a gloved hand and lets her fingertips graze a dented antenna, the gesture small and reverent."
|
| 23 |
+
output:
|
| 24 |
+
url: assets/astronaut.mp4
|
| 25 |
+
- text: "A woman standing in a sunlit field as flower petals swirl around her in slow motion. Each petal floats gently through the golden light, casting tiny shadows. Her hair moves like water, and time seems to stand still."
|
| 26 |
+
output:
|
| 27 |
+
url: assets/woman.mp4
|
| 28 |
---
|
| 29 |
|
| 30 |
<p align="center">
|
|
|
|
| 47 |
|
| 48 |
---
|
| 49 |
|
| 50 |
+
<!--
|
| 51 |
+
NOTE: This README is written against the CURRENT state of diffusers PR #13551
|
| 52 |
+
(pre-merge). The PR currently has issues:
|
| 53 |
+
- negative_prompt defaults to None (should be built-in)
|
| 54 |
+
- use_linear_quadratic_schedule defaults to True (should be False)
|
| 55 |
+
- DPMSolverMultistepScheduler crashes (pipeline always passes sigmas)
|
| 56 |
+
- No built-in SageAttention support (requires manual patching)
|
| 57 |
+
|
| 58 |
+
Code examples below include workarounds (explicit negative_prompt,
|
| 59 |
+
use_linear_quadratic_schedule=False, _FlowDPMSolver subclass).
|
| 60 |
+
|
| 61 |
+
TODO: Update after PR feedback is applied, and again after merge.
|
| 62 |
+
Tracking: https://github.com/MotifTechnologies/diffusers/pull/1
|
| 63 |
+
-->
|
| 64 |
+
|
| 65 |
## 🔥 News
|
| 66 |
|
| 67 |
+
- **[2026-04-28]** **ComfyUI custom nodes** released: [ComfyUI-MotifVideo2B](https://github.com/MotifTechnologies/ComfyUI-MotifVideo2B). GGUF workflow support coming soon.
|
| 68 |
+
- **[2026-04-28]** **GGUF quantized weights** now available at [Motif-Video-2B-GGUF](https://huggingface.co/Motif-Technologies/Motif-Video-2B-GGUF) — up to 2.7 GB VRAM savings with no speed penalty. **SageAttention** support for ~2× faster inference. See [GGUF + SageAttention](#🧊-gguf--sageattention) below.
|
| 69 |
- **[2026-04-14]** We release **Motif-Video 2B**, our 2B-parameter text-to-video and image-to-video diffusion transformer, together with the full [technical report](https://arxiv.org/abs/2604.16503).
|
| 70 |
|
| 71 |
---
|
|
|
|
| 141 |
### Requirements
|
| 142 |
|
| 143 |
- Python 3.10+
|
| 144 |
+
- CUDA-capable GPU with **30GB+ VRAM** (e.g., A100, H100) — for 24GB GPUs see [Memory-efficient Inference](🔋-memory-efficient-inference)
|
| 145 |
|
| 146 |
```bash
|
| 147 |
+
pip install "transformers>=5.5.4" torch accelerate ftfy einops sentencepiece regex Pillow imageio imageio-ffmpeg
|
| 148 |
+
pip install git+https://github.com/waitingcheung/diffusers.git@feat/motif-video
|
| 149 |
```
|
| 150 |
|
| 151 |
### Text-to-Video (T2V)
|
| 152 |
|
| 153 |
```python
|
| 154 |
import torch
|
| 155 |
+
from diffusers import (
|
| 156 |
+
AdaptiveProjectedGuidance,
|
| 157 |
+
DPMSolverMultistepScheduler,
|
| 158 |
+
MotifVideoPipeline,
|
| 159 |
+
)
|
| 160 |
from diffusers.utils import export_to_video
|
| 161 |
|
| 162 |
+
|
| 163 |
+
# DPMSolver++ subclass: ignores pipeline-supplied sigmas and builds its own
|
| 164 |
+
# flow-matching schedule. Will be unnecessary once PR #13551 adds the
|
| 165 |
+
# _is_flow_multistep branch.
|
| 166 |
+
class FlowDPMSolver(DPMSolverMultistepScheduler):
|
| 167 |
+
def set_timesteps(self, num_inference_steps=None, device=None,
|
| 168 |
+
sigmas=None, mu=None, timesteps=None):
|
| 169 |
+
if sigmas is not None and num_inference_steps is None:
|
| 170 |
+
num_inference_steps = len(sigmas)
|
| 171 |
+
super().set_timesteps(
|
| 172 |
+
num_inference_steps=num_inference_steps,
|
| 173 |
+
device=device, timesteps=timesteps,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
guider = AdaptiveProjectedGuidance(
|
| 178 |
guidance_scale=8.0,
|
| 179 |
adaptive_projected_guidance_rescale=12.0,
|
| 180 |
adaptive_projected_guidance_momentum=0.1,
|
| 181 |
use_original_formulation=True,
|
| 182 |
+
normalization_dims="spatial",
|
| 183 |
)
|
| 184 |
|
| 185 |
+
pipe = MotifVideoPipeline.from_pretrained(
|
| 186 |
"Motif-Technologies/Motif-Video-2B",
|
| 187 |
+
revision="diffusers-integration",
|
|
|
|
| 188 |
torch_dtype=torch.bfloat16,
|
| 189 |
guider=guider,
|
| 190 |
)
|
| 191 |
+
|
| 192 |
+
# DPMSolver++ for faster convergence
|
| 193 |
+
pipe.scheduler = FlowDPMSolver(
|
| 194 |
+
num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
|
| 195 |
+
algorithm_type="dpmsolver++",
|
| 196 |
+
solver_order=2,
|
| 197 |
+
prediction_type="flow_prediction",
|
| 198 |
+
use_flow_sigmas=True,
|
| 199 |
+
flow_shift=15.0,
|
| 200 |
+
)
|
| 201 |
pipe = pipe.to("cuda")
|
| 202 |
|
| 203 |
output = pipe(
|
| 204 |
+
prompt="A woman standing in a sunlit field as flower petals swirl around her in slow motion. Each petal floats gently through the golden light, casting tiny shadows. Her hair moves like water, and time seems to stand still.",
|
| 205 |
+
negative_prompt="text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
|
| 206 |
height=736,
|
| 207 |
width=1280,
|
| 208 |
num_frames=121,
|
| 209 |
num_inference_steps=50,
|
| 210 |
+
frame_rate=24,
|
| 211 |
+
use_linear_quadratic_schedule=False,
|
| 212 |
)
|
| 213 |
|
| 214 |
export_to_video(output.frames[0], "output.mp4", fps=24)
|
|
|
|
| 218 |
|
| 219 |
```python
|
| 220 |
import torch
|
| 221 |
+
from diffusers import (
|
| 222 |
+
AdaptiveProjectedGuidance,
|
| 223 |
+
DPMSolverMultistepScheduler,
|
| 224 |
+
MotifVideoPipeline,
|
| 225 |
+
)
|
| 226 |
from diffusers.utils import export_to_video, load_image
|
| 227 |
|
| 228 |
guider = AdaptiveProjectedGuidance(
|
|
|
|
| 230 |
adaptive_projected_guidance_rescale=12.0,
|
| 231 |
adaptive_projected_guidance_momentum=0.1,
|
| 232 |
use_original_formulation=True,
|
| 233 |
+
normalization_dims="spatial",
|
| 234 |
)
|
| 235 |
|
| 236 |
+
pipe = MotifVideoPipeline.from_pretrained(
|
| 237 |
"Motif-Technologies/Motif-Video-2B",
|
| 238 |
+
revision="diffusers-integration",
|
|
|
|
| 239 |
torch_dtype=torch.bfloat16,
|
| 240 |
guider=guider,
|
| 241 |
)
|
| 242 |
+
|
| 243 |
+
pipe.scheduler = FlowDPMSolver(
|
| 244 |
+
num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
|
| 245 |
+
algorithm_type="dpmsolver++",
|
| 246 |
+
solver_order=2,
|
| 247 |
+
prediction_type="flow_prediction",
|
| 248 |
+
use_flow_sigmas=True,
|
| 249 |
+
flow_shift=15.0,
|
| 250 |
+
)
|
| 251 |
pipe = pipe.to("cuda")
|
| 252 |
|
| 253 |
image = load_image("https://huggingface.co/Motif-Technologies/Motif-Video-2B/resolve/main/assets/i2v_sample.jpg")
|
| 254 |
|
| 255 |
output = pipe(
|
| 256 |
+
prompt="Three friends stride through a sun-bleached meadow as a warm breeze ripples the tall dry grass around their legs.",
|
| 257 |
+
negative_prompt="text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
|
| 258 |
image=image,
|
| 259 |
height=736,
|
| 260 |
width=1280,
|
| 261 |
num_frames=121,
|
| 262 |
num_inference_steps=50,
|
| 263 |
+
frame_rate=24,
|
| 264 |
+
use_linear_quadratic_schedule=False,
|
| 265 |
)
|
| 266 |
|
| 267 |
export_to_video(output.frames[0], "output.mp4", fps=24)
|
|
|
|
| 270 |
### CLI Inference
|
| 271 |
|
| 272 |
```bash
|
| 273 |
+
# Text-to-Video (default settings)
|
| 274 |
python inference.py \
|
| 275 |
+
--prompt "A woman standing in a sunlit field as..." \
|
| 276 |
--output t2v_output.mp4
|
| 277 |
|
| 278 |
+
# With SageAttention (~2x faster, requires sageattention package)
|
| 279 |
python inference.py \
|
| 280 |
+
--prompt "Three friends stride through a sun-bleached meadow..." \
|
| 281 |
+
--use-sage-attention \
|
| 282 |
+
--output t2v_output.mp4
|
| 283 |
```
|
| 284 |
|
| 285 |
+
See `inference.py --help` for all available options.
|
| 286 |
|
| 287 |
### Recommended Settings
|
| 288 |
|
| 289 |
| Parameter | Default | Notes |
|
| 290 |
|---|---|---|
|
| 291 |
+
| Resolution | 1280×736 | 720p, best quality |
|
| 292 |
| Frames | 121 | ~5 seconds at 24fps |
|
| 293 |
+
| Scheduler | DPMSolver++ | `solver_order=2`, `flow_shift=15.0` |
|
| 294 |
+
| Guidance scale | 8.0 | With APG (`normalization_dims="spatial"`) |
|
| 295 |
| Inference steps | 50 | |
|
| 296 |
+
| Negative prompt | (built-in) | See code examples above |
|
| 297 |
+
| `use_linear_quadratic_schedule` | `False` | Must be set explicitly |
|
| 298 |
| dtype | bfloat16 | Recommended for H100/A100 |
|
| 299 |
|
| 300 |
### 🔋 Memory-efficient Inference
|
| 301 |
|
| 302 |
+
For GPUs with 24 GB or less (e.g. RTX 4090, RTX 3090), CPU offloading and FP8 quantization can reduce peak VRAM from ~30 GB to ~15 GB with minimal speed impact.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
+
| Mode | Peak VRAM | Recommended GPU |
|
| 305 |
+
|------|-----------|-----------------|
|
| 306 |
+
| `pipe.to("cuda")` | ~30 GB | A100, H100, H200 |
|
| 307 |
+
| `enable_model_cpu_offload()` | ~19 GB | RTX 4090, RTX 3090 |
|
| 308 |
+
| `+ FP8 quantization` | ~15 GB | RTX 4090, RTX 3090 |
|
| 309 |
|
| 310 |
+
> **Full guide** → [docs/memory-efficient-inference.md](docs/memory-efficient-inference.md)
|
| 311 |
|
| 312 |
+
---
|
|
|
|
|
|
|
| 313 |
|
| 314 |
+
### 🧊 GGUF + SageAttention
|
|
|
|
| 315 |
|
| 316 |
+
GGUF quantized weights at [Motif-Video-2B-GGUF](https://huggingface.co/Motif-Technologies/Motif-Video-2B-GGUF) — up to 2.7 GB VRAM savings with no speed penalty. Combined with [SageAttention](https://github.com/thu-ml/SageAttention) for ~1.6× faster inference.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
+
| Variant | Sage (s/it) | Speedup | Peak alloc (GB) |
|
| 319 |
+
|---------|------------|---------|-----------------|
|
| 320 |
+
| BF16 | 14.75 | 1.58x | 15.12 |
|
| 321 |
+
| Q8_0 | 14.49 | 1.60x | 13.44 |
|
| 322 |
+
| Q4_K_M | 14.59 | 1.60x | 12.53 |
|
| 323 |
|
| 324 |
+
> **Full guide** → [docs/gguf-sageattention.md](docs/gguf-sageattention.md)
|
| 325 |
|
| 326 |
+
---
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
### 🖥️ ComfyUI
|
| 329 |
|
| 330 |
+
Official ComfyUI custom nodes: [ComfyUI-MotifVideo2B](https://github.com/MotifTechnologies/ComfyUI-MotifVideo2B)
|
| 331 |
+
|
| 332 |
+
> **Note:** Currently requires **High VRAM** mode. GGUF quantized model loading in ComfyUI is in progress.
|
| 333 |
|
| 334 |
---
|
| 335 |
|
_fm_solvers_unipc.py
DELETED
|
@@ -1,759 +0,0 @@
|
|
| 1 |
-
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
|
| 2 |
-
# Convert unipc for flow matching
|
| 3 |
-
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 4 |
-
|
| 5 |
-
import math
|
| 6 |
-
from typing import List, Optional, Tuple, Union
|
| 7 |
-
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 11 |
-
from diffusers.schedulers.scheduling_utils import (
|
| 12 |
-
KarrasDiffusionSchedulers,
|
| 13 |
-
SchedulerMixin,
|
| 14 |
-
SchedulerOutput,
|
| 15 |
-
)
|
| 16 |
-
from diffusers.utils import deprecate, is_scipy_available
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
if is_scipy_available():
|
| 20 |
-
pass
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
| 24 |
-
"""
|
| 25 |
-
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
|
| 26 |
-
|
| 27 |
-
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 28 |
-
methods the library implements for all schedulers such as loading and saving.
|
| 29 |
-
|
| 30 |
-
Args:
|
| 31 |
-
num_train_timesteps (`int`, defaults to 1000):
|
| 32 |
-
The number of diffusion steps to train the model.
|
| 33 |
-
solver_order (`int`, default `2`):
|
| 34 |
-
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
|
| 35 |
-
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
|
| 36 |
-
unconditional sampling.
|
| 37 |
-
prediction_type (`str`, defaults to "flow_prediction"):
|
| 38 |
-
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
|
| 39 |
-
the flow of the diffusion process.
|
| 40 |
-
thresholding (`bool`, defaults to `False`):
|
| 41 |
-
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
| 42 |
-
as Stable Diffusion.
|
| 43 |
-
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
| 44 |
-
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
| 45 |
-
sample_max_value (`float`, defaults to 1.0):
|
| 46 |
-
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
|
| 47 |
-
predict_x0 (`bool`, defaults to `True`):
|
| 48 |
-
Whether to use the updating algorithm on the predicted x0.
|
| 49 |
-
solver_type (`str`, default `bh2`):
|
| 50 |
-
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
|
| 51 |
-
otherwise.
|
| 52 |
-
lower_order_final (`bool`, default `True`):
|
| 53 |
-
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
| 54 |
-
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
| 55 |
-
disable_corrector (`list`, default `[]`):
|
| 56 |
-
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
|
| 57 |
-
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
|
| 58 |
-
usually disabled during the first few steps.
|
| 59 |
-
solver_p (`SchedulerMixin`, default `None`):
|
| 60 |
-
Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
|
| 61 |
-
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
| 62 |
-
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
| 63 |
-
the sigmas are determined according to a sequence of noise levels {σi}.
|
| 64 |
-
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
| 65 |
-
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
| 66 |
-
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 67 |
-
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 68 |
-
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 69 |
-
steps_offset (`int`, defaults to 0):
|
| 70 |
-
An offset added to the inference steps, as required by some model families.
|
| 71 |
-
final_sigmas_type (`str`, defaults to `"zero"`):
|
| 72 |
-
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 73 |
-
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 74 |
-
"""
|
| 75 |
-
|
| 76 |
-
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
| 77 |
-
order = 1
|
| 78 |
-
|
| 79 |
-
@register_to_config
|
| 80 |
-
def __init__(
|
| 81 |
-
self,
|
| 82 |
-
num_train_timesteps: int = 1000,
|
| 83 |
-
solver_order: int = 2,
|
| 84 |
-
prediction_type: str = "flow_prediction",
|
| 85 |
-
shift: Optional[float] = 1.0,
|
| 86 |
-
use_dynamic_shifting=False,
|
| 87 |
-
thresholding: bool = False,
|
| 88 |
-
dynamic_thresholding_ratio: float = 0.995,
|
| 89 |
-
sample_max_value: float = 1.0,
|
| 90 |
-
predict_x0: bool = True,
|
| 91 |
-
solver_type: str = "bh2",
|
| 92 |
-
lower_order_final: bool = True,
|
| 93 |
-
disable_corrector: List[int] = [],
|
| 94 |
-
solver_p: Optional[SchedulerMixin] = None,
|
| 95 |
-
timestep_spacing: str = "linspace",
|
| 96 |
-
steps_offset: int = 0,
|
| 97 |
-
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
| 98 |
-
):
|
| 99 |
-
if solver_type not in ["bh1", "bh2"]:
|
| 100 |
-
if solver_type in ["midpoint", "heun", "logrho"]:
|
| 101 |
-
self.register_to_config(solver_type="bh2")
|
| 102 |
-
else:
|
| 103 |
-
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
| 104 |
-
|
| 105 |
-
self.predict_x0 = predict_x0
|
| 106 |
-
# setable values
|
| 107 |
-
self.num_inference_steps = None
|
| 108 |
-
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
|
| 109 |
-
sigmas = 1.0 - alphas
|
| 110 |
-
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
|
| 111 |
-
|
| 112 |
-
if not use_dynamic_shifting:
|
| 113 |
-
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
| 114 |
-
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
|
| 115 |
-
|
| 116 |
-
self.sigmas = sigmas
|
| 117 |
-
self.timesteps = sigmas * num_train_timesteps
|
| 118 |
-
|
| 119 |
-
self.model_outputs = [None] * solver_order
|
| 120 |
-
self.timestep_list = [None] * solver_order
|
| 121 |
-
self.lower_order_nums = 0
|
| 122 |
-
self.disable_corrector = disable_corrector
|
| 123 |
-
self.solver_p = solver_p
|
| 124 |
-
self.last_sample = None
|
| 125 |
-
self._step_index = None
|
| 126 |
-
self._begin_index = None
|
| 127 |
-
|
| 128 |
-
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 129 |
-
self.sigma_min = self.sigmas[-1].item()
|
| 130 |
-
self.sigma_max = self.sigmas[0].item()
|
| 131 |
-
|
| 132 |
-
@property
|
| 133 |
-
def step_index(self):
|
| 134 |
-
"""
|
| 135 |
-
The index counter for current timestep. It will increase 1 after each scheduler step.
|
| 136 |
-
"""
|
| 137 |
-
return self._step_index
|
| 138 |
-
|
| 139 |
-
@property
|
| 140 |
-
def begin_index(self):
|
| 141 |
-
"""
|
| 142 |
-
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
| 143 |
-
"""
|
| 144 |
-
return self._begin_index
|
| 145 |
-
|
| 146 |
-
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
| 147 |
-
def set_begin_index(self, begin_index: int = 0):
|
| 148 |
-
"""
|
| 149 |
-
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
| 150 |
-
|
| 151 |
-
Args:
|
| 152 |
-
begin_index (`int`):
|
| 153 |
-
The begin index for the scheduler.
|
| 154 |
-
"""
|
| 155 |
-
self._begin_index = begin_index
|
| 156 |
-
|
| 157 |
-
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
|
| 158 |
-
def set_timesteps(
|
| 159 |
-
self,
|
| 160 |
-
num_inference_steps: Union[int, None] = None,
|
| 161 |
-
device: Optional[Union[str, torch.device]] = None,
|
| 162 |
-
sigmas: Optional[List[float]] = None,
|
| 163 |
-
mu: Optional[Union[float, None]] = None,
|
| 164 |
-
shift: Optional[Union[float, None]] = None,
|
| 165 |
-
):
|
| 166 |
-
"""
|
| 167 |
-
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
| 168 |
-
Args:
|
| 169 |
-
num_inference_steps (`int`):
|
| 170 |
-
Total number of the spacing of the time steps.
|
| 171 |
-
device (`str` or `torch.device`, *optional*):
|
| 172 |
-
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 173 |
-
"""
|
| 174 |
-
|
| 175 |
-
if self.config.use_dynamic_shifting and mu is None:
|
| 176 |
-
raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
| 177 |
-
|
| 178 |
-
if sigmas is None:
|
| 179 |
-
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore
|
| 180 |
-
|
| 181 |
-
if self.config.use_dynamic_shifting:
|
| 182 |
-
sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
|
| 183 |
-
else:
|
| 184 |
-
if shift is None:
|
| 185 |
-
shift = self.config.shift
|
| 186 |
-
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
|
| 187 |
-
|
| 188 |
-
if self.config.final_sigmas_type == "sigma_min":
|
| 189 |
-
sigma_last = self.config.sigma_min
|
| 190 |
-
elif self.config.final_sigmas_type == "zero":
|
| 191 |
-
sigma_last = 0
|
| 192 |
-
else:
|
| 193 |
-
raise ValueError(
|
| 194 |
-
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
| 195 |
-
)
|
| 196 |
-
|
| 197 |
-
timesteps = sigmas * self.config.num_train_timesteps
|
| 198 |
-
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
|
| 199 |
-
|
| 200 |
-
self.sigmas = torch.from_numpy(sigmas)
|
| 201 |
-
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
| 202 |
-
|
| 203 |
-
self.num_inference_steps = len(timesteps)
|
| 204 |
-
|
| 205 |
-
self.model_outputs = [
|
| 206 |
-
None,
|
| 207 |
-
] * self.config.solver_order
|
| 208 |
-
self.lower_order_nums = 0
|
| 209 |
-
self.last_sample = None
|
| 210 |
-
if self.solver_p:
|
| 211 |
-
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
| 212 |
-
|
| 213 |
-
# add an index counter for schedulers that allow duplicated timesteps
|
| 214 |
-
self._step_index = None
|
| 215 |
-
self._begin_index = None
|
| 216 |
-
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 217 |
-
|
| 218 |
-
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
| 219 |
-
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
| 220 |
-
"""
|
| 221 |
-
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
| 222 |
-
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
| 223 |
-
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
| 224 |
-
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
| 225 |
-
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
| 226 |
-
|
| 227 |
-
https://arxiv.org/abs/2205.11487
|
| 228 |
-
"""
|
| 229 |
-
dtype = sample.dtype
|
| 230 |
-
batch_size, channels, *remaining_dims = sample.shape
|
| 231 |
-
|
| 232 |
-
if dtype not in (torch.float32, torch.float64):
|
| 233 |
-
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
| 234 |
-
|
| 235 |
-
# Flatten sample for doing quantile calculation along each image
|
| 236 |
-
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
| 237 |
-
|
| 238 |
-
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
| 239 |
-
|
| 240 |
-
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
| 241 |
-
s = torch.clamp(
|
| 242 |
-
s, min=1, max=self.config.sample_max_value
|
| 243 |
-
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
| 244 |
-
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
| 245 |
-
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
| 246 |
-
|
| 247 |
-
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
| 248 |
-
sample = sample.to(dtype)
|
| 249 |
-
|
| 250 |
-
return sample
|
| 251 |
-
|
| 252 |
-
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
|
| 253 |
-
def _sigma_to_t(self, sigma):
|
| 254 |
-
return sigma * self.config.num_train_timesteps
|
| 255 |
-
|
| 256 |
-
def _sigma_to_alpha_sigma_t(self, sigma):
|
| 257 |
-
return 1 - sigma, sigma
|
| 258 |
-
|
| 259 |
-
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
|
| 260 |
-
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
| 261 |
-
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 262 |
-
|
| 263 |
-
def convert_model_output(
|
| 264 |
-
self,
|
| 265 |
-
model_output: torch.Tensor,
|
| 266 |
-
*args,
|
| 267 |
-
sample: Optional[torch.Tensor] = None,
|
| 268 |
-
**kwargs,
|
| 269 |
-
) -> torch.Tensor:
|
| 270 |
-
r"""
|
| 271 |
-
Convert the model output to the corresponding type the UniPC algorithm needs.
|
| 272 |
-
|
| 273 |
-
Args:
|
| 274 |
-
model_output (`torch.Tensor`):
|
| 275 |
-
The direct output from the learned diffusion model.
|
| 276 |
-
timestep (`int`):
|
| 277 |
-
The current discrete timestep in the diffusion chain.
|
| 278 |
-
sample (`torch.Tensor`):
|
| 279 |
-
A current instance of a sample created by the diffusion process.
|
| 280 |
-
|
| 281 |
-
Returns:
|
| 282 |
-
`torch.Tensor`:
|
| 283 |
-
The converted model output.
|
| 284 |
-
"""
|
| 285 |
-
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
| 286 |
-
if sample is None:
|
| 287 |
-
if len(args) > 1:
|
| 288 |
-
sample = args[1]
|
| 289 |
-
else:
|
| 290 |
-
raise ValueError("missing `sample` as a required keyward argument")
|
| 291 |
-
if timestep is not None:
|
| 292 |
-
deprecate(
|
| 293 |
-
"timesteps",
|
| 294 |
-
"1.0.0",
|
| 295 |
-
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 296 |
-
)
|
| 297 |
-
|
| 298 |
-
sigma = self.sigmas[self.step_index]
|
| 299 |
-
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 300 |
-
|
| 301 |
-
if self.predict_x0:
|
| 302 |
-
if self.config.prediction_type == "flow_prediction":
|
| 303 |
-
sigma_t = self.sigmas[self.step_index]
|
| 304 |
-
x0_pred = sample - sigma_t * model_output
|
| 305 |
-
else:
|
| 306 |
-
raise ValueError(
|
| 307 |
-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
|
| 308 |
-
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
if self.config.thresholding:
|
| 312 |
-
x0_pred = self._threshold_sample(x0_pred)
|
| 313 |
-
|
| 314 |
-
return x0_pred
|
| 315 |
-
else:
|
| 316 |
-
if self.config.prediction_type == "flow_prediction":
|
| 317 |
-
sigma_t = self.sigmas[self.step_index]
|
| 318 |
-
epsilon = sample - (1 - sigma_t) * model_output
|
| 319 |
-
else:
|
| 320 |
-
raise ValueError(
|
| 321 |
-
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
|
| 322 |
-
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
if self.config.thresholding:
|
| 326 |
-
sigma_t = self.sigmas[self.step_index]
|
| 327 |
-
x0_pred = sample - sigma_t * model_output
|
| 328 |
-
x0_pred = self._threshold_sample(x0_pred)
|
| 329 |
-
epsilon = model_output + x0_pred
|
| 330 |
-
|
| 331 |
-
return epsilon
|
| 332 |
-
|
| 333 |
-
def multistep_uni_p_bh_update(
|
| 334 |
-
self,
|
| 335 |
-
model_output: torch.Tensor,
|
| 336 |
-
*args,
|
| 337 |
-
sample: Optional[torch.Tensor] = None,
|
| 338 |
-
order: Optional[int] = None,
|
| 339 |
-
**kwargs,
|
| 340 |
-
) -> torch.Tensor:
|
| 341 |
-
"""
|
| 342 |
-
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
| 343 |
-
|
| 344 |
-
Args:
|
| 345 |
-
model_output (`torch.Tensor`):
|
| 346 |
-
The direct output from the learned diffusion model at the current timestep.
|
| 347 |
-
prev_timestep (`int`):
|
| 348 |
-
The previous discrete timestep in the diffusion chain.
|
| 349 |
-
sample (`torch.Tensor`):
|
| 350 |
-
A current instance of a sample created by the diffusion process.
|
| 351 |
-
order (`int`):
|
| 352 |
-
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
|
| 353 |
-
|
| 354 |
-
Returns:
|
| 355 |
-
`torch.Tensor`:
|
| 356 |
-
The sample tensor at the previous timestep.
|
| 357 |
-
"""
|
| 358 |
-
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
|
| 359 |
-
if sample is None:
|
| 360 |
-
if len(args) > 1:
|
| 361 |
-
sample = args[1]
|
| 362 |
-
else:
|
| 363 |
-
raise ValueError(" missing `sample` as a required keyward argument")
|
| 364 |
-
if order is None:
|
| 365 |
-
if len(args) > 2:
|
| 366 |
-
order = args[2]
|
| 367 |
-
else:
|
| 368 |
-
raise ValueError(" missing `order` as a required keyward argument")
|
| 369 |
-
if prev_timestep is not None:
|
| 370 |
-
deprecate(
|
| 371 |
-
"prev_timestep",
|
| 372 |
-
"1.0.0",
|
| 373 |
-
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 374 |
-
)
|
| 375 |
-
model_output_list = self.model_outputs
|
| 376 |
-
|
| 377 |
-
s0 = self.timestep_list[-1]
|
| 378 |
-
m0 = model_output_list[-1]
|
| 379 |
-
x = sample
|
| 380 |
-
|
| 381 |
-
if self.solver_p:
|
| 382 |
-
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
| 383 |
-
return x_t
|
| 384 |
-
|
| 385 |
-
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore
|
| 386 |
-
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
| 387 |
-
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
| 388 |
-
|
| 389 |
-
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
| 390 |
-
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
| 391 |
-
|
| 392 |
-
h = lambda_t - lambda_s0
|
| 393 |
-
device = sample.device
|
| 394 |
-
|
| 395 |
-
rks = []
|
| 396 |
-
D1s = []
|
| 397 |
-
for i in range(1, order):
|
| 398 |
-
si = self.step_index - i # pyright: ignore
|
| 399 |
-
mi = model_output_list[-(i + 1)]
|
| 400 |
-
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
| 401 |
-
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
| 402 |
-
rk = (lambda_si - lambda_s0) / h
|
| 403 |
-
rks.append(rk)
|
| 404 |
-
D1s.append((mi - m0) / rk) # pyright: ignore
|
| 405 |
-
|
| 406 |
-
rks.append(1.0)
|
| 407 |
-
rks = torch.tensor(rks, device=device)
|
| 408 |
-
|
| 409 |
-
R = []
|
| 410 |
-
b = []
|
| 411 |
-
|
| 412 |
-
hh = -h if self.predict_x0 else h
|
| 413 |
-
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
| 414 |
-
h_phi_k = h_phi_1 / hh - 1
|
| 415 |
-
|
| 416 |
-
factorial_i = 1
|
| 417 |
-
|
| 418 |
-
if self.config.solver_type == "bh1":
|
| 419 |
-
B_h = hh
|
| 420 |
-
elif self.config.solver_type == "bh2":
|
| 421 |
-
B_h = torch.expm1(hh)
|
| 422 |
-
else:
|
| 423 |
-
raise NotImplementedError()
|
| 424 |
-
|
| 425 |
-
for i in range(1, order + 1):
|
| 426 |
-
R.append(torch.pow(rks, i - 1))
|
| 427 |
-
b.append(h_phi_k * factorial_i / B_h)
|
| 428 |
-
factorial_i *= i + 1
|
| 429 |
-
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
| 430 |
-
|
| 431 |
-
R = torch.stack(R)
|
| 432 |
-
b = torch.tensor(b, device=device)
|
| 433 |
-
|
| 434 |
-
if len(D1s) > 0:
|
| 435 |
-
D1s = torch.stack(D1s, dim=1) # (B, K)
|
| 436 |
-
# for order 2, we use a simplified version
|
| 437 |
-
if order == 2:
|
| 438 |
-
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
|
| 439 |
-
else:
|
| 440 |
-
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
|
| 441 |
-
else:
|
| 442 |
-
D1s = None
|
| 443 |
-
|
| 444 |
-
if self.predict_x0:
|
| 445 |
-
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
| 446 |
-
if D1s is not None:
|
| 447 |
-
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
|
| 448 |
-
else:
|
| 449 |
-
pred_res = 0
|
| 450 |
-
x_t = x_t_ - alpha_t * B_h * pred_res
|
| 451 |
-
else:
|
| 452 |
-
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
| 453 |
-
if D1s is not None:
|
| 454 |
-
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
|
| 455 |
-
else:
|
| 456 |
-
pred_res = 0
|
| 457 |
-
x_t = x_t_ - sigma_t * B_h * pred_res
|
| 458 |
-
|
| 459 |
-
x_t = x_t.to(x.dtype)
|
| 460 |
-
return x_t
|
| 461 |
-
|
| 462 |
-
def multistep_uni_c_bh_update(
|
| 463 |
-
self,
|
| 464 |
-
this_model_output: torch.Tensor,
|
| 465 |
-
*args,
|
| 466 |
-
last_sample: Optional[torch.Tensor] = None,
|
| 467 |
-
this_sample: Optional[torch.Tensor] = None,
|
| 468 |
-
order: Optional[int] = None,
|
| 469 |
-
**kwargs,
|
| 470 |
-
) -> torch.Tensor:
|
| 471 |
-
"""
|
| 472 |
-
One step for the UniC (B(h) version).
|
| 473 |
-
|
| 474 |
-
Args:
|
| 475 |
-
this_model_output (`torch.Tensor`):
|
| 476 |
-
The model outputs at `x_t`.
|
| 477 |
-
this_timestep (`int`):
|
| 478 |
-
The current timestep `t`.
|
| 479 |
-
last_sample (`torch.Tensor`):
|
| 480 |
-
The generated sample before the last predictor `x_{t-1}`.
|
| 481 |
-
this_sample (`torch.Tensor`):
|
| 482 |
-
The generated sample after the last predictor `x_{t}`.
|
| 483 |
-
order (`int`):
|
| 484 |
-
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
|
| 485 |
-
|
| 486 |
-
Returns:
|
| 487 |
-
`torch.Tensor`:
|
| 488 |
-
The corrected sample tensor at the current timestep.
|
| 489 |
-
"""
|
| 490 |
-
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
|
| 491 |
-
if last_sample is None:
|
| 492 |
-
if len(args) > 1:
|
| 493 |
-
last_sample = args[1]
|
| 494 |
-
else:
|
| 495 |
-
raise ValueError(" missing`last_sample` as a required keyward argument")
|
| 496 |
-
if this_sample is None:
|
| 497 |
-
if len(args) > 2:
|
| 498 |
-
this_sample = args[2]
|
| 499 |
-
else:
|
| 500 |
-
raise ValueError(" missing`this_sample` as a required keyward argument")
|
| 501 |
-
if order is None:
|
| 502 |
-
if len(args) > 3:
|
| 503 |
-
order = args[3]
|
| 504 |
-
else:
|
| 505 |
-
raise ValueError(" missing`order` as a required keyward argument")
|
| 506 |
-
if this_timestep is not None:
|
| 507 |
-
deprecate(
|
| 508 |
-
"this_timestep",
|
| 509 |
-
"1.0.0",
|
| 510 |
-
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 511 |
-
)
|
| 512 |
-
|
| 513 |
-
model_output_list = self.model_outputs
|
| 514 |
-
|
| 515 |
-
m0 = model_output_list[-1]
|
| 516 |
-
x = last_sample
|
| 517 |
-
x_t = this_sample
|
| 518 |
-
model_t = this_model_output
|
| 519 |
-
|
| 520 |
-
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore
|
| 521 |
-
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
| 522 |
-
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
| 523 |
-
|
| 524 |
-
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
| 525 |
-
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
| 526 |
-
|
| 527 |
-
h = lambda_t - lambda_s0
|
| 528 |
-
device = this_sample.device
|
| 529 |
-
|
| 530 |
-
rks = []
|
| 531 |
-
D1s = []
|
| 532 |
-
for i in range(1, order):
|
| 533 |
-
si = self.step_index - (i + 1) # pyright: ignore
|
| 534 |
-
mi = model_output_list[-(i + 1)]
|
| 535 |
-
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
| 536 |
-
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
| 537 |
-
rk = (lambda_si - lambda_s0) / h
|
| 538 |
-
rks.append(rk)
|
| 539 |
-
D1s.append((mi - m0) / rk) # pyright: ignore
|
| 540 |
-
|
| 541 |
-
rks.append(1.0)
|
| 542 |
-
rks = torch.tensor(rks, device=device)
|
| 543 |
-
|
| 544 |
-
R = []
|
| 545 |
-
b = []
|
| 546 |
-
|
| 547 |
-
hh = -h if self.predict_x0 else h
|
| 548 |
-
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
| 549 |
-
h_phi_k = h_phi_1 / hh - 1
|
| 550 |
-
|
| 551 |
-
factorial_i = 1
|
| 552 |
-
|
| 553 |
-
if self.config.solver_type == "bh1":
|
| 554 |
-
B_h = hh
|
| 555 |
-
elif self.config.solver_type == "bh2":
|
| 556 |
-
B_h = torch.expm1(hh)
|
| 557 |
-
else:
|
| 558 |
-
raise NotImplementedError()
|
| 559 |
-
|
| 560 |
-
for i in range(1, order + 1):
|
| 561 |
-
R.append(torch.pow(rks, i - 1))
|
| 562 |
-
b.append(h_phi_k * factorial_i / B_h)
|
| 563 |
-
factorial_i *= i + 1
|
| 564 |
-
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
| 565 |
-
|
| 566 |
-
R = torch.stack(R)
|
| 567 |
-
b = torch.tensor(b, device=device)
|
| 568 |
-
|
| 569 |
-
if len(D1s) > 0:
|
| 570 |
-
D1s = torch.stack(D1s, dim=1)
|
| 571 |
-
else:
|
| 572 |
-
D1s = None
|
| 573 |
-
|
| 574 |
-
# for order 1, we use a simplified version
|
| 575 |
-
if order == 1:
|
| 576 |
-
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
|
| 577 |
-
else:
|
| 578 |
-
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
|
| 579 |
-
|
| 580 |
-
if self.predict_x0:
|
| 581 |
-
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
| 582 |
-
if D1s is not None:
|
| 583 |
-
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
| 584 |
-
else:
|
| 585 |
-
corr_res = 0
|
| 586 |
-
D1_t = model_t - m0
|
| 587 |
-
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
| 588 |
-
else:
|
| 589 |
-
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
| 590 |
-
if D1s is not None:
|
| 591 |
-
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
| 592 |
-
else:
|
| 593 |
-
corr_res = 0
|
| 594 |
-
D1_t = model_t - m0
|
| 595 |
-
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
| 596 |
-
x_t = x_t.to(x.dtype)
|
| 597 |
-
return x_t
|
| 598 |
-
|
| 599 |
-
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
| 600 |
-
if schedule_timesteps is None:
|
| 601 |
-
schedule_timesteps = self.timesteps
|
| 602 |
-
|
| 603 |
-
indices = (schedule_timesteps == timestep).nonzero()
|
| 604 |
-
|
| 605 |
-
# The sigma index that is taken for the **very** first `step`
|
| 606 |
-
# is always the second index (or the last index if there is only 1)
|
| 607 |
-
# This way we can ensure we don't accidentally skip a sigma in
|
| 608 |
-
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
| 609 |
-
pos = 1 if len(indices) > 1 else 0
|
| 610 |
-
|
| 611 |
-
return indices[pos].item()
|
| 612 |
-
|
| 613 |
-
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
| 614 |
-
def _init_step_index(self, timestep):
|
| 615 |
-
"""
|
| 616 |
-
Initialize the step_index counter for the scheduler.
|
| 617 |
-
"""
|
| 618 |
-
|
| 619 |
-
if self.begin_index is None:
|
| 620 |
-
if isinstance(timestep, torch.Tensor):
|
| 621 |
-
timestep = timestep.to(self.timesteps.device)
|
| 622 |
-
self._step_index = self.index_for_timestep(timestep)
|
| 623 |
-
else:
|
| 624 |
-
self._step_index = self._begin_index
|
| 625 |
-
|
| 626 |
-
def step(
|
| 627 |
-
self,
|
| 628 |
-
model_output: torch.Tensor,
|
| 629 |
-
timestep: Union[int, torch.Tensor],
|
| 630 |
-
sample: torch.Tensor,
|
| 631 |
-
return_dict: bool = True,
|
| 632 |
-
generator=None,
|
| 633 |
-
) -> Union[SchedulerOutput, Tuple]:
|
| 634 |
-
"""
|
| 635 |
-
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
|
| 636 |
-
the multistep UniPC.
|
| 637 |
-
|
| 638 |
-
Args:
|
| 639 |
-
model_output (`torch.Tensor`):
|
| 640 |
-
The direct output from learned diffusion model.
|
| 641 |
-
timestep (`int`):
|
| 642 |
-
The current discrete timestep in the diffusion chain.
|
| 643 |
-
sample (`torch.Tensor`):
|
| 644 |
-
A current instance of a sample created by the diffusion process.
|
| 645 |
-
return_dict (`bool`):
|
| 646 |
-
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
| 647 |
-
|
| 648 |
-
Returns:
|
| 649 |
-
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
| 650 |
-
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
| 651 |
-
tuple is returned where the first element is the sample tensor.
|
| 652 |
-
|
| 653 |
-
"""
|
| 654 |
-
if self.num_inference_steps is None:
|
| 655 |
-
raise ValueError(
|
| 656 |
-
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
| 657 |
-
)
|
| 658 |
-
|
| 659 |
-
if self.step_index is None:
|
| 660 |
-
self._init_step_index(timestep)
|
| 661 |
-
|
| 662 |
-
use_corrector = (
|
| 663 |
-
self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None # pyright: ignore
|
| 664 |
-
)
|
| 665 |
-
|
| 666 |
-
model_output_convert = self.convert_model_output(model_output, sample=sample)
|
| 667 |
-
if use_corrector:
|
| 668 |
-
sample = self.multistep_uni_c_bh_update(
|
| 669 |
-
this_model_output=model_output_convert,
|
| 670 |
-
last_sample=self.last_sample,
|
| 671 |
-
this_sample=sample,
|
| 672 |
-
order=self.this_order,
|
| 673 |
-
)
|
| 674 |
-
|
| 675 |
-
for i in range(self.config.solver_order - 1):
|
| 676 |
-
self.model_outputs[i] = self.model_outputs[i + 1]
|
| 677 |
-
self.timestep_list[i] = self.timestep_list[i + 1]
|
| 678 |
-
|
| 679 |
-
self.model_outputs[-1] = model_output_convert
|
| 680 |
-
self.timestep_list[-1] = timestep # pyright: ignore
|
| 681 |
-
|
| 682 |
-
if self.config.lower_order_final:
|
| 683 |
-
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore
|
| 684 |
-
else:
|
| 685 |
-
this_order = self.config.solver_order
|
| 686 |
-
|
| 687 |
-
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
|
| 688 |
-
assert self.this_order > 0
|
| 689 |
-
|
| 690 |
-
self.last_sample = sample
|
| 691 |
-
prev_sample = self.multistep_uni_p_bh_update(
|
| 692 |
-
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
| 693 |
-
sample=sample,
|
| 694 |
-
order=self.this_order,
|
| 695 |
-
)
|
| 696 |
-
|
| 697 |
-
if self.lower_order_nums < self.config.solver_order:
|
| 698 |
-
self.lower_order_nums += 1
|
| 699 |
-
|
| 700 |
-
# upon completion increase step index by one
|
| 701 |
-
self._step_index += 1 # pyright: ignore
|
| 702 |
-
|
| 703 |
-
if not return_dict:
|
| 704 |
-
return (prev_sample,)
|
| 705 |
-
|
| 706 |
-
return SchedulerOutput(prev_sample=prev_sample)
|
| 707 |
-
|
| 708 |
-
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 709 |
-
"""
|
| 710 |
-
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
| 711 |
-
current timestep.
|
| 712 |
-
|
| 713 |
-
Args:
|
| 714 |
-
sample (`torch.Tensor`):
|
| 715 |
-
The input sample.
|
| 716 |
-
|
| 717 |
-
Returns:
|
| 718 |
-
`torch.Tensor`:
|
| 719 |
-
A scaled input sample.
|
| 720 |
-
"""
|
| 721 |
-
return sample
|
| 722 |
-
|
| 723 |
-
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
| 724 |
-
def add_noise(
|
| 725 |
-
self,
|
| 726 |
-
original_samples: torch.Tensor,
|
| 727 |
-
noise: torch.Tensor,
|
| 728 |
-
timesteps: torch.IntTensor,
|
| 729 |
-
) -> torch.Tensor:
|
| 730 |
-
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
| 731 |
-
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 732 |
-
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
| 733 |
-
# mps does not support float64
|
| 734 |
-
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
| 735 |
-
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
| 736 |
-
else:
|
| 737 |
-
schedule_timesteps = self.timesteps.to(original_samples.device)
|
| 738 |
-
timesteps = timesteps.to(original_samples.device)
|
| 739 |
-
|
| 740 |
-
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
| 741 |
-
if self.begin_index is None:
|
| 742 |
-
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
| 743 |
-
elif self.step_index is not None:
|
| 744 |
-
# add_noise is called after first denoising step (for inpainting)
|
| 745 |
-
step_indices = [self.step_index] * timesteps.shape[0]
|
| 746 |
-
else:
|
| 747 |
-
# add noise is called before first denoising step to create initial latent(img2img)
|
| 748 |
-
step_indices = [self.begin_index] * timesteps.shape[0]
|
| 749 |
-
|
| 750 |
-
sigma = sigmas[step_indices].flatten()
|
| 751 |
-
while len(sigma.shape) < len(original_samples.shape):
|
| 752 |
-
sigma = sigma.unsqueeze(-1)
|
| 753 |
-
|
| 754 |
-
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 755 |
-
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
| 756 |
-
return noisy_samples
|
| 757 |
-
|
| 758 |
-
def __len__(self):
|
| 759 |
-
return self.config.num_train_timesteps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assets/astronaut.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5c13fd69a91c40ce217162e1bee917c23b86fcd878301bcab11a48fcd3bfeded
|
| 3 |
+
size 1020141
|
assets/bird.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ceddcad73d270f8a03a5bfa5ab6cc3dc74c9b1ff3db3a652151b5c26efd36e9c
|
| 3 |
+
size 757207
|
assets/fisherman.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e67764b0104e66ec56dab3cc216f3887aef84d25f3bb7758443de1b8624c745
|
| 3 |
+
size 1506256
|
assets/i2v_sample.jpg
ADDED
|
Git LFS Details
|
assets/sage_compare_BF16.webp
ADDED
|
Git LFS Details
|
assets/sage_compare_Q4_K_M.webp
ADDED
|
Git LFS Details
|
assets/sage_compare_Q5_K_M.webp
ADDED
|
Git LFS Details
|
assets/sage_compare_Q8_0.webp
ADDED
|
Git LFS Details
|
assets/underwater.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:184b72b8672bca924bf8f0b568488daac507eafd8dd8ddcb15f9928549309550
|
| 3 |
+
size 2024562
|
assets/woman.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2aa2e93ccc2d333f26323a0944e4a3e9c0ee3064824ae23b245f5ab2c548a947
|
| 3 |
+
size 2057707
|
docs/gguf-sageattention.md
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🧊 GGUF + SageAttention
|
| 2 |
+
|
| 3 |
+
> See the main [README](../README.md) for `FlowDPMSolver` and pipeline setup.
|
| 4 |
+
|
| 5 |
+
GGUF quantized transformer weights are available at [Motif-Video-2B-GGUF](https://huggingface.co/Motif-Technologies/Motif-Video-2B-GGUF), reducing VRAM with minimal quality loss. Combined with [SageAttention](https://github.com/thu-ml/SageAttention) for ~2× faster attention computation.
|
| 6 |
+
|
| 7 |
+
## GGUF Inference
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
pip install gguf
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
```python
|
| 14 |
+
import torch
|
| 15 |
+
from diffusers import (
|
| 16 |
+
AdaptiveProjectedGuidance,
|
| 17 |
+
DPMSolverMultistepScheduler,
|
| 18 |
+
GGUFQuantizationConfig,
|
| 19 |
+
MotifVideoPipeline,
|
| 20 |
+
MotifVideoTransformer3DModel,
|
| 21 |
+
)
|
| 22 |
+
from diffusers.utils import export_to_video
|
| 23 |
+
from huggingface_hub import hf_hub_download
|
| 24 |
+
|
| 25 |
+
guider = AdaptiveProjectedGuidance(
|
| 26 |
+
guidance_scale=8.0,
|
| 27 |
+
adaptive_projected_guidance_rescale=12.0,
|
| 28 |
+
adaptive_projected_guidance_momentum=0.1,
|
| 29 |
+
use_original_formulation=True,
|
| 30 |
+
normalization_dims="spatial",
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
variant = "Q4_K_M" # Options: Q4_0, Q4_1, Q4_K_M, Q5_0, Q5_1, Q5_K_M, Q6_K, Q8_0, BF16
|
| 34 |
+
ckpt_path = hf_hub_download(
|
| 35 |
+
"Motif-Technologies/Motif-Video-2B-GGUF",
|
| 36 |
+
filename=f"motifv-2b-dev-{variant}.gguf",
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
transformer = MotifVideoTransformer3DModel.from_single_file(
|
| 40 |
+
ckpt_path,
|
| 41 |
+
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
| 42 |
+
config="Motif-Technologies/Motif-Video-2B",
|
| 43 |
+
revision="diffusers-integration",
|
| 44 |
+
subfolder="transformer",
|
| 45 |
+
torch_dtype=torch.bfloat16,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
pipe = MotifVideoPipeline.from_pretrained(
|
| 49 |
+
"Motif-Technologies/Motif-Video-2B",
|
| 50 |
+
revision="diffusers-integration",
|
| 51 |
+
torch_dtype=torch.bfloat16,
|
| 52 |
+
guider=guider,
|
| 53 |
+
transformer=transformer,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
pipe.scheduler = FlowDPMSolver(
|
| 57 |
+
num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
|
| 58 |
+
algorithm_type="dpmsolver++",
|
| 59 |
+
solver_order=2,
|
| 60 |
+
prediction_type="flow_prediction",
|
| 61 |
+
use_flow_sigmas=True,
|
| 62 |
+
flow_shift=15.0,
|
| 63 |
+
)
|
| 64 |
+
pipe.enable_model_cpu_offload()
|
| 65 |
+
|
| 66 |
+
output = pipe(
|
| 67 |
+
prompt="A woman standing in a sunlit field as flower petals swirl around her in slow motion. Each petal floats gently through the golden light, casting tiny shadows. Her hair moves like water, and time seems to stand still.",
|
| 68 |
+
negative_prompt="text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
|
| 69 |
+
height=736,
|
| 70 |
+
width=1280,
|
| 71 |
+
num_frames=121,
|
| 72 |
+
num_inference_steps=50,
|
| 73 |
+
frame_rate=24,
|
| 74 |
+
use_linear_quadratic_schedule=False,
|
| 75 |
+
)
|
| 76 |
+
export_to_video(output.frames[0], "output.mp4", fps=24)
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## SageAttention (Optional, ~1.6× faster)
|
| 80 |
+
|
| 81 |
+
Same prompt and seed, 1280x736, 121 frames, 50 steps. Left = SDPA, Right = SageAttention.
|
| 82 |
+
|
| 83 |
+

|
| 84 |
+

|
| 85 |
+

|
| 86 |
+

|
| 87 |
+
|
| 88 |
+
[SageAttention](https://github.com/thu-ml/SageAttention) accelerates attention by quantizing Q/K to INT8 and V to FP8, reducing memory bandwidth. Works with all GGUF variants.
|
| 89 |
+
|
| 90 |
+
**Install** (build from source — PyPI only has 1.x, need 2.x):
|
| 91 |
+
|
| 92 |
+
```bash
|
| 93 |
+
# Set TORCH_CUDA_ARCH_LIST to match your GPU: "8.0" for A100, "9.0" for H100/H200
|
| 94 |
+
TORCH_CUDA_ARCH_LIST="9.0" pip install git+https://github.com/thu-ml/SageAttention.git --no-build-isolation
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
**Usage with `inference.py`:**
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
python inference.py --use-sage-attention --prompt "..."
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
**Notes:**
|
| 104 |
+
- Requires NVIDIA GPU with SM70+
|
| 105 |
+
- SM90+ (H100, H200) — FP8 kernels for maximum speedup
|
| 106 |
+
- SM80-SM89 (A100, RTX 3090, RTX 4090) — FP16 kernels (still faster than SDPA)
|
| 107 |
+
- SM70-SM75 (V100, RTX 2080 Ti) — FP16 kernels
|
| 108 |
+
- Set `TORCH_CUDA_ARCH_LIST` to match your GPU when building (e.g., `"8.6"` for RTX 3090, `"8.9"` for RTX 4090)
|
| 109 |
+
- No quality degradation observed across all GGUF variants
|
| 110 |
+
|
| 111 |
+
## Benchmark
|
| 112 |
+
|
| 113 |
+
Measured on NVIDIA H200, 1280x736, 121 frames, 50 steps, DPMSolver++ (order=2, flow_shift=15.0):
|
| 114 |
+
|
| 115 |
+
| Variant | SDPA (s/it) | Sage (s/it) | Speedup | Peak alloc (GB) | Peak rsv (GB) | Total SDPA (s) | Total Sage (s) |
|
| 116 |
+
|---------|------------|------------|---------|-----------------|----------------|----------------|----------------|
|
| 117 |
+
| BF16 | 23.36 | 14.75 | 1.58x | 14.78 / 15.12 | 24.93 / 24.90 | 1184 | 754 |
|
| 118 |
+
| Q8_0 | 23.16 | 14.49 | 1.60x | 13.10 / 13.44 | 23.14 / 23.11 | 1178 | 744 |
|
| 119 |
+
| Q6_K | 23.21 | 14.55 | 1.60x | 12.62 / 12.95 | 22.72 / 22.69 | 1178 | 747 |
|
| 120 |
+
| Q5_K_M | 23.33 | 14.69 | 1.59x | 12.39 / 12.72 | 22.45 / 22.42 | 1184 | 754 |
|
| 121 |
+
| Q5_1 | 23.54 | 14.96 | 1.57x | 12.47 / 12.81 | 22.66 / 22.62 | 1193 | 764 |
|
| 122 |
+
| Q5_0 | 23.26 | 14.67 | 1.59x | 12.37 / 12.71 | 22.55 / 22.52 | 1179 | 750 |
|
| 123 |
+
| Q4_K_M | 23.25 | 14.59 | 1.60x | 12.19 / 12.53 | 22.22 / 22.18 | 1178 | 747 |
|
| 124 |
+
| Q4_1 | 23.31 | 14.68 | 1.59x | 12.26 / 12.60 | 22.26 / 22.22 | 1181 | 750 |
|
| 125 |
+
| Q4_0 | 23.33 | 14.75 | 1.58x | 12.14 / 12.47 | 22.18 / 22.14 | 1188 | 760 |
|
| 126 |
+
|
| 127 |
+
Peak alloc/rsv columns show SDPA / Sage values. Sage adds ~0.3 GB alloc overhead (INT8/FP8 quantization buffers) with no change in reserved memory.
|
| 128 |
+
|
| 129 |
+
**Key findings:**
|
| 130 |
+
- **~1.59x faster with SageAttention** — consistent across all quantization levels
|
| 131 |
+
- **VRAM unchanged** — sage overhead is negligible (~0.3 GB alloc)
|
| 132 |
+
- **GGUF + Sage stacks** — Q4_K_M + Sage achieves 14.59 s/it at 12.53 GB alloc (vs BF16 SDPA: 23.36 s/it at 14.78 GB)
|
docs/memory-efficient-inference.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Memory-efficient Inference
|
| 2 |
+
|
| 3 |
+
> See the main [README](../README.md) for `FlowDPMSolver` and `guider` setup.
|
| 4 |
+
|
| 5 |
+
By default, `pipe.to("cuda")` loads all components onto the GPU simultaneously, requiring **~30 GB VRAM**.
|
| 6 |
+
|
| 7 |
+
For GPUs with 24 GB or less (e.g. RTX 4090, RTX 3090), use `enable_model_cpu_offload()` with the `expandable_segments` allocator setting:
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
```python
|
| 14 |
+
pipe = MotifVideoPipeline.from_pretrained(
|
| 15 |
+
"Motif-Technologies/Motif-Video-2B",
|
| 16 |
+
revision="diffusers-integration",
|
| 17 |
+
torch_dtype=torch.bfloat16,
|
| 18 |
+
guider=guider, # see T2V example above
|
| 19 |
+
)
|
| 20 |
+
pipe.scheduler = FlowDPMSolver(
|
| 21 |
+
num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
|
| 22 |
+
algorithm_type="dpmsolver++",
|
| 23 |
+
solver_order=2,
|
| 24 |
+
prediction_type="flow_prediction",
|
| 25 |
+
use_flow_sigmas=True,
|
| 26 |
+
flow_shift=15.0,
|
| 27 |
+
)
|
| 28 |
+
pipe.enable_model_cpu_offload() # replaces pipe.to("cuda")
|
| 29 |
+
|
| 30 |
+
output = pipe(
|
| 31 |
+
prompt="...",
|
| 32 |
+
negative_prompt="...",
|
| 33 |
+
height=736, width=1280, num_frames=121, num_inference_steps=50,
|
| 34 |
+
frame_rate=24, use_linear_quadratic_schedule=False,
|
| 35 |
+
)
|
| 36 |
+
export_to_video(output.frames[0], "output.mp4", fps=24)
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
This moves each component (text encoder → transformer → VAE) to GPU only when needed. The `expandable_segments` setting allows the CUDA memory allocator to efficiently reuse memory released by earlier components, avoiding fragmentation-related OOM errors.
|
| 40 |
+
|
| 41 |
+
| Mode | Peak VRAM | Speed | Recommended GPU |
|
| 42 |
+
|------|-----------|-------|-----------------|
|
| 43 |
+
| `pipe.to("cuda")` | ~30 GB | Fastest | A100, H100, H200 |
|
| 44 |
+
| `enable_model_cpu_offload()` | ~19 GB | Similar | RTX 4090, RTX 3090 |
|
| 45 |
+
|
| 46 |
+
## FP8 Weight Quantization (Optional)
|
| 47 |
+
|
| 48 |
+
For further VRAM reduction, you can quantize the transformer weights to FP8 using [torchao](https://github.com/pytorch/ao):
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
pip install torchao
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
```python
|
| 55 |
+
from torchao.quantization import quantize_, Float8WeightOnlyConfig
|
| 56 |
+
|
| 57 |
+
pipe = MotifVideoPipeline.from_pretrained(
|
| 58 |
+
"Motif-Technologies/Motif-Video-2B",
|
| 59 |
+
revision="diffusers-integration",
|
| 60 |
+
torch_dtype=torch.bfloat16,
|
| 61 |
+
guider=guider, # see T2V example above
|
| 62 |
+
)
|
| 63 |
+
pipe.scheduler = FlowDPMSolver(
|
| 64 |
+
num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
|
| 65 |
+
algorithm_type="dpmsolver++",
|
| 66 |
+
solver_order=2,
|
| 67 |
+
prediction_type="flow_prediction",
|
| 68 |
+
use_flow_sigmas=True,
|
| 69 |
+
flow_shift=15.0,
|
| 70 |
+
)
|
| 71 |
+
quantize_(pipe.transformer, Float8WeightOnlyConfig())
|
| 72 |
+
pipe.enable_model_cpu_offload()
|
| 73 |
+
|
| 74 |
+
output = pipe(
|
| 75 |
+
prompt="...",
|
| 76 |
+
negative_prompt="...",
|
| 77 |
+
height=736, width=1280, num_frames=121, num_inference_steps=50,
|
| 78 |
+
frame_rate=24, use_linear_quadratic_schedule=False,
|
| 79 |
+
)
|
| 80 |
+
export_to_video(output.frames[0], "output.mp4", fps=24)
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
This stores the transformer weights in FP8 (8-bit) instead of BF16 (16-bit), reducing peak VRAM from ~19 GB to ~15 GB while keeping all computation in BF16 precision.
|
| 84 |
+
|
| 85 |
+
| Mode | Peak VRAM | Notes |
|
| 86 |
+
|------|-----------|-------|
|
| 87 |
+
| `enable_model_cpu_offload()` | ~19 GB | BF16 baseline |
|
| 88 |
+
| `+ Float8WeightOnlyConfig` | ~15 GB | FP8 weights, BF16 compute |
|
inference.py
DELETED
|
@@ -1,119 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""Motif-Video 2B — Text-to-Video & Image-to-Video inference.
|
| 3 |
-
|
| 4 |
-
GPU requirements: ~24GB VRAM for 720p (1280x736, 121 frames).
|
| 5 |
-
Tested with: torch>=2.0, diffusers>=0.35.2, transformers>=5.0.0
|
| 6 |
-
|
| 7 |
-
Uses Adaptive Projected Guidance (APG) by default for best quality.
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
import argparse
|
| 11 |
-
|
| 12 |
-
import torch
|
| 13 |
-
from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
|
| 14 |
-
from diffusers.utils import export_to_video
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def parse_args():
|
| 18 |
-
parser = argparse.ArgumentParser(description="Motif-Video 2B Inference (T2V / I2V)")
|
| 19 |
-
parser.add_argument(
|
| 20 |
-
"--model-path",
|
| 21 |
-
type=str,
|
| 22 |
-
default="Motif-Technologies/Motif-Video-2B",
|
| 23 |
-
help="HuggingFace model ID or local checkpoint path (uses trust_remote_code=True)",
|
| 24 |
-
)
|
| 25 |
-
parser.add_argument(
|
| 26 |
-
"--prompt",
|
| 27 |
-
type=str,
|
| 28 |
-
default="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
|
| 29 |
-
help="Text prompt for video generation",
|
| 30 |
-
)
|
| 31 |
-
parser.add_argument(
|
| 32 |
-
"--image",
|
| 33 |
-
type=str,
|
| 34 |
-
default=None,
|
| 35 |
-
help="Path to input image for I2V mode (omit for T2V)",
|
| 36 |
-
)
|
| 37 |
-
parser.add_argument(
|
| 38 |
-
"--negative-prompt",
|
| 39 |
-
type=str,
|
| 40 |
-
default=None,
|
| 41 |
-
help="Negative prompt (default: built-in pipeline default)",
|
| 42 |
-
)
|
| 43 |
-
parser.add_argument("--output", type=str, default="output.mp4", help="Output video file path")
|
| 44 |
-
parser.add_argument("--num-frames", type=int, default=121, help="Number of frames to generate (121 = ~5s at 24fps)")
|
| 45 |
-
parser.add_argument("--height", type=int, default=736, help="Video height in pixels")
|
| 46 |
-
parser.add_argument("--width", type=int, default=1280, help="Video width in pixels")
|
| 47 |
-
parser.add_argument("--guidance-scale", type=float, default=8.0, help="Classifier-free guidance scale")
|
| 48 |
-
parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of denoising steps")
|
| 49 |
-
parser.add_argument("--fps", type=int, default=24, help="Output video frame rate")
|
| 50 |
-
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
|
| 51 |
-
parser.add_argument(
|
| 52 |
-
"--dtype",
|
| 53 |
-
type=str,
|
| 54 |
-
default="bfloat16",
|
| 55 |
-
choices=["float16", "bfloat16", "float32"],
|
| 56 |
-
help="Model dtype",
|
| 57 |
-
)
|
| 58 |
-
return parser.parse_args()
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def main():
|
| 62 |
-
args = parse_args()
|
| 63 |
-
|
| 64 |
-
dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
|
| 65 |
-
torch_dtype = dtype_map[args.dtype]
|
| 66 |
-
|
| 67 |
-
mode = "I2V" if args.image else "T2V"
|
| 68 |
-
print(f"[{mode}] Loading model from: {args.model_path}")
|
| 69 |
-
|
| 70 |
-
guider = AdaptiveProjectedGuidance(
|
| 71 |
-
guidance_scale=args.guidance_scale,
|
| 72 |
-
adaptive_projected_guidance_rescale=12.0,
|
| 73 |
-
adaptive_projected_guidance_momentum=0.1,
|
| 74 |
-
eta=0.0,
|
| 75 |
-
use_original_formulation=True,
|
| 76 |
-
)
|
| 77 |
-
|
| 78 |
-
pipe = DiffusionPipeline.from_pretrained(
|
| 79 |
-
args.model_path,
|
| 80 |
-
custom_pipeline="pipeline_motif_video",
|
| 81 |
-
trust_remote_code=True,
|
| 82 |
-
torch_dtype=torch_dtype,
|
| 83 |
-
guider=guider,
|
| 84 |
-
)
|
| 85 |
-
pipe = pipe.to("cuda")
|
| 86 |
-
|
| 87 |
-
generator = torch.Generator(device="cuda").manual_seed(args.seed)
|
| 88 |
-
|
| 89 |
-
# Load image for I2V mode
|
| 90 |
-
image = None
|
| 91 |
-
if args.image:
|
| 92 |
-
from PIL import Image
|
| 93 |
-
|
| 94 |
-
image = Image.open(args.image).convert("RGB")
|
| 95 |
-
print(f"[I2V] Input image: {args.image} ({image.size[0]}x{image.size[1]})")
|
| 96 |
-
|
| 97 |
-
print(f"Generating video: {args.width}x{args.height}, {args.num_frames} frames, {args.num_inference_steps} steps")
|
| 98 |
-
pipe_kwargs = dict(
|
| 99 |
-
prompt=args.prompt,
|
| 100 |
-
image=image,
|
| 101 |
-
height=args.height,
|
| 102 |
-
width=args.width,
|
| 103 |
-
num_frames=args.num_frames,
|
| 104 |
-
num_inference_steps=args.num_inference_steps,
|
| 105 |
-
generator=generator,
|
| 106 |
-
frame_rate=args.fps,
|
| 107 |
-
)
|
| 108 |
-
if args.negative_prompt is not None:
|
| 109 |
-
pipe_kwargs["negative_prompt"] = args.negative_prompt
|
| 110 |
-
|
| 111 |
-
output = pipe(**pipe_kwargs)
|
| 112 |
-
|
| 113 |
-
video_frames = output.frames[0]
|
| 114 |
-
export_to_video(video_frames, args.output, fps=args.fps)
|
| 115 |
-
print(f"Video saved to: {args.output}")
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
if __name__ == "__main__":
|
| 119 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_index.json
CHANGED
|
@@ -14,7 +14,7 @@
|
|
| 14 |
"GemmaTokenizer"
|
| 15 |
],
|
| 16 |
"transformer": [
|
| 17 |
-
"
|
| 18 |
"MotifVideoTransformer3DModel"
|
| 19 |
],
|
| 20 |
"vae": [
|
|
|
|
| 14 |
"GemmaTokenizer"
|
| 15 |
],
|
| 16 |
"transformer": [
|
| 17 |
+
"diffusers",
|
| 18 |
"MotifVideoTransformer3DModel"
|
| 19 |
],
|
| 20 |
"vae": [
|
pipeline_motif_video.py
DELETED
|
@@ -1,1388 +0,0 @@
|
|
| 1 |
-
# Copyright 2026 Motif Technologies, Inc. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import html
|
| 16 |
-
import inspect
|
| 17 |
-
from dataclasses import dataclass
|
| 18 |
-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
-
|
| 20 |
-
import ftfy
|
| 21 |
-
import numpy as np
|
| 22 |
-
import regex as re
|
| 23 |
-
import torch
|
| 24 |
-
from diffusers import (
|
| 25 |
-
AdaptiveProjectedGuidance,
|
| 26 |
-
AutoencoderKLWan,
|
| 27 |
-
ClassifierFreeGuidance,
|
| 28 |
-
DiffusionPipeline,
|
| 29 |
-
DPMSolverMultistepScheduler,
|
| 30 |
-
FlowMatchEulerDiscreteScheduler,
|
| 31 |
-
SkipLayerGuidance,
|
| 32 |
-
UniPCMultistepScheduler,
|
| 33 |
-
)
|
| 34 |
-
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 35 |
-
from diffusers.guiders.adaptive_projected_guidance import MomentumBuffer
|
| 36 |
-
from diffusers.guiders.guider_utils import GuiderOutput
|
| 37 |
-
from diffusers.utils import (
|
| 38 |
-
BaseOutput,
|
| 39 |
-
is_torch_xla_available,
|
| 40 |
-
logging,
|
| 41 |
-
replace_example_docstring,
|
| 42 |
-
)
|
| 43 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 44 |
-
from diffusers.video_processor import VideoProcessor
|
| 45 |
-
from einops import rearrange
|
| 46 |
-
from PIL import Image
|
| 47 |
-
from torch import Tensor
|
| 48 |
-
|
| 49 |
-
from transformers import (
|
| 50 |
-
BatchEncoding,
|
| 51 |
-
PreTrainedTokenizerBase,
|
| 52 |
-
SiglipImageProcessor,
|
| 53 |
-
T5Gemma2Encoder,
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
from ._fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
if is_torch_xla_available():
|
| 60 |
-
import torch_xla.core.xla_model as xm
|
| 61 |
-
|
| 62 |
-
XLA_AVAILABLE = True
|
| 63 |
-
else:
|
| 64 |
-
XLA_AVAILABLE = False
|
| 65 |
-
|
| 66 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 67 |
-
|
| 68 |
-
EXAMPLE_DOC_STRING = """
|
| 69 |
-
Examples:
|
| 70 |
-
```py
|
| 71 |
-
>>> import torch
|
| 72 |
-
>>> from diffusers import MotifVideoPipeline
|
| 73 |
-
>>> from diffusers.utils import export_to_video
|
| 74 |
-
|
| 75 |
-
>>> # Load the Motif Video pipeline
|
| 76 |
-
>>> motif_video_model_id = "MotifTechnologies/Motif-Video"
|
| 77 |
-
>>> pipe = MotifVideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16)
|
| 78 |
-
>>> pipe.to("cuda")
|
| 79 |
-
|
| 80 |
-
>>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
| 81 |
-
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
| 82 |
-
|
| 83 |
-
>>> video = pipe(
|
| 84 |
-
... prompt=prompt,
|
| 85 |
-
... negative_prompt=negative_prompt,
|
| 86 |
-
... width=640,
|
| 87 |
-
... height=352,
|
| 88 |
-
... num_frames=65,
|
| 89 |
-
... num_inference_steps=50,
|
| 90 |
-
... ).frames[0]
|
| 91 |
-
>>> export_to_video(video, "output.mp4", fps=16)
|
| 92 |
-
```
|
| 93 |
-
"""
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
@dataclass
|
| 97 |
-
class MotifVideoPipelineOutput(BaseOutput):
|
| 98 |
-
r"""
|
| 99 |
-
Output class for Motif Video pipelines.
|
| 100 |
-
|
| 101 |
-
Args:
|
| 102 |
-
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 103 |
-
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 104 |
-
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 105 |
-
`(batch_size, num_frames, channels, height, width)`.
|
| 106 |
-
"""
|
| 107 |
-
|
| 108 |
-
frames: torch.Tensor
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
"""Video-aware Adaptive Projected Guidance (APG).
|
| 112 |
-
|
| 113 |
-
Standard APG normalizes over all spatial dimensions [C, T, H, W], which collapses
|
| 114 |
-
temporal variation. This module normalizes over [C, H, W] only, preserving
|
| 115 |
-
per-frame independence.
|
| 116 |
-
"""
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def video_normalized_guidance(
|
| 120 |
-
pred_cond: torch.Tensor,
|
| 121 |
-
pred_uncond: torch.Tensor,
|
| 122 |
-
guidance_scale: float,
|
| 123 |
-
momentum_buffer: MomentumBuffer | None = None,
|
| 124 |
-
eta: float = 1.0,
|
| 125 |
-
norm_threshold: float = 0.0,
|
| 126 |
-
use_original_formulation: bool = False,
|
| 127 |
-
) -> torch.Tensor:
|
| 128 |
-
"""APG with video-aware normalization: normalize over [C, H, W], exclude T.
|
| 129 |
-
|
| 130 |
-
For 5D input [B, C, T, H, W], dim=[-1, -2, -4] normalizes per-frame (W, H, C),
|
| 131 |
-
keeping the T dimension independent. For 4D input [B, C, H, W], falls back to
|
| 132 |
-
standard [-1, -2, -3] behavior.
|
| 133 |
-
"""
|
| 134 |
-
diff = pred_cond - pred_uncond
|
| 135 |
-
|
| 136 |
-
if len(diff.shape) == 5:
|
| 137 |
-
# [B, C, T, H, W] → normalize over W(-1), H(-2), C(-4), skip T(-3)
|
| 138 |
-
dim = [-1, -2, -4]
|
| 139 |
-
else:
|
| 140 |
-
# [B, C, H, W] → standard behavior
|
| 141 |
-
dim = [-i for i in range(1, len(diff.shape))]
|
| 142 |
-
|
| 143 |
-
if momentum_buffer is not None:
|
| 144 |
-
momentum_buffer.update(diff)
|
| 145 |
-
diff = momentum_buffer.running_average
|
| 146 |
-
|
| 147 |
-
if norm_threshold > 0:
|
| 148 |
-
ones = torch.ones_like(diff)
|
| 149 |
-
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
| 150 |
-
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
| 151 |
-
diff = diff * scale_factor
|
| 152 |
-
|
| 153 |
-
v0, v1 = diff.double(), pred_cond.double()
|
| 154 |
-
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
| 155 |
-
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
| 156 |
-
v0_orthogonal = v0 - v0_parallel
|
| 157 |
-
diff_parallel, diff_orthogonal = (
|
| 158 |
-
v0_parallel.type_as(diff),
|
| 159 |
-
v0_orthogonal.type_as(diff),
|
| 160 |
-
)
|
| 161 |
-
normalized_update = diff_orthogonal + eta * diff_parallel
|
| 162 |
-
|
| 163 |
-
pred = pred_cond if use_original_formulation else pred_uncond
|
| 164 |
-
pred = pred + guidance_scale * normalized_update
|
| 165 |
-
|
| 166 |
-
return pred
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
class VideoAdaptiveProjectedGuidance(AdaptiveProjectedGuidance):
|
| 170 |
-
"""APG variant that normalizes over [C, H, W] per frame, excluding the T dimension."""
|
| 171 |
-
|
| 172 |
-
def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
|
| 173 |
-
pred = None
|
| 174 |
-
|
| 175 |
-
if not self._is_apg_enabled():
|
| 176 |
-
pred = pred_cond
|
| 177 |
-
else:
|
| 178 |
-
pred = video_normalized_guidance(
|
| 179 |
-
pred_cond,
|
| 180 |
-
pred_uncond,
|
| 181 |
-
self.guidance_scale,
|
| 182 |
-
self.momentum_buffer,
|
| 183 |
-
self.eta,
|
| 184 |
-
self.adaptive_projected_guidance_rescale,
|
| 185 |
-
self.use_original_formulation,
|
| 186 |
-
)
|
| 187 |
-
|
| 188 |
-
if self.guidance_rescale > 0.0:
|
| 189 |
-
from diffusers.guiders.classifier_free_guidance import rescale_noise_cfg
|
| 190 |
-
|
| 191 |
-
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
| 192 |
-
|
| 193 |
-
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
| 197 |
-
def calculate_shift(
|
| 198 |
-
image_seq_len,
|
| 199 |
-
base_seq_len: int = 256,
|
| 200 |
-
max_seq_len: int = 4096,
|
| 201 |
-
base_shift: float = 0.5,
|
| 202 |
-
max_shift: float = 1.15,
|
| 203 |
-
):
|
| 204 |
-
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 205 |
-
b = base_shift - m * base_seq_len
|
| 206 |
-
mu = image_seq_len * m + b
|
| 207 |
-
return mu
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
def get_linear_quadratic_sigmas(
|
| 211 |
-
num_inference_steps: int,
|
| 212 |
-
linear_quadratic_emulating_steps: int = 250,
|
| 213 |
-
) -> np.ndarray:
|
| 214 |
-
"""
|
| 215 |
-
Compute a linear-quadratic sigma schedule for flow matching.
|
| 216 |
-
|
| 217 |
-
This schedule combines:
|
| 218 |
-
- First half: Linear interpolation from high noise to medium noise (slow denoising)
|
| 219 |
-
- Second half: Quadratic interpolation from medium noise to clean (faster denoising)
|
| 220 |
-
|
| 221 |
-
Convention:
|
| 222 |
-
- sigma=1.0 represents pure noise
|
| 223 |
-
- sigma=0.0 represents clean image
|
| 224 |
-
- Output sigmas are in descending order (1.0 → ~0)
|
| 225 |
-
|
| 226 |
-
Args:
|
| 227 |
-
num_inference_steps: Total number of denoising steps (must be even).
|
| 228 |
-
linear_quadratic_emulating_steps: Controls the slope of linear interpolation.
|
| 229 |
-
Higher values result in gentler slope in the first half.
|
| 230 |
-
|
| 231 |
-
Returns:
|
| 232 |
-
np.ndarray: Array of sigma values with shape (num_inference_steps,).
|
| 233 |
-
The scheduler will append a terminal 0.
|
| 234 |
-
|
| 235 |
-
Raises:
|
| 236 |
-
ValueError: If num_inference_steps is not even.
|
| 237 |
-
|
| 238 |
-
Reference:
|
| 239 |
-
Linear-quadratic timestep schedule for improved flow matching inference.
|
| 240 |
-
"""
|
| 241 |
-
if num_inference_steps % 2 != 0:
|
| 242 |
-
raise ValueError(
|
| 243 |
-
f"num_inference_steps must be even for linear-quadratic schedule, but got {num_inference_steps}"
|
| 244 |
-
)
|
| 245 |
-
|
| 246 |
-
steps = num_inference_steps
|
| 247 |
-
N = linear_quadratic_emulating_steps
|
| 248 |
-
half_steps = steps // 2
|
| 249 |
-
|
| 250 |
-
# First half: linear interpolation from 1 toward 0
|
| 251 |
-
# Takes first half_steps values from linspace(1, 0, N+1)
|
| 252 |
-
linear_part = np.linspace(1.0, 0.0, N + 1)[:half_steps]
|
| 253 |
-
|
| 254 |
-
# Second half: quadratic interpolation
|
| 255 |
-
# Formula: x^2 * (half_steps/N - 1) - (half_steps/N - 1)
|
| 256 |
-
# = (half_steps/N - 1) * (x^2 - 1)
|
| 257 |
-
# This maps x=0 to (half_steps/N - 1) * (-1) = 1 - half_steps/N
|
| 258 |
-
# and maps x=1 to 0
|
| 259 |
-
x = np.linspace(0.0, 1.0, half_steps + 1)
|
| 260 |
-
scale_factor = half_steps / N - 1 # negative value
|
| 261 |
-
quadratic_part = x**2 * scale_factor - scale_factor
|
| 262 |
-
|
| 263 |
-
# Concatenate and exclude the last 0 (scheduler appends terminal 0)
|
| 264 |
-
sigmas = np.concatenate([linear_part, quadratic_part])
|
| 265 |
-
sigmas = sigmas[:-1] # Remove trailing 0, scheduler will append it
|
| 266 |
-
|
| 267 |
-
return sigmas.astype(np.float32)
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 271 |
-
def retrieve_timesteps(
|
| 272 |
-
scheduler,
|
| 273 |
-
num_inference_steps: Optional[int] = None,
|
| 274 |
-
device: Optional[Union[str, torch.device]] = None,
|
| 275 |
-
timesteps: Optional[List[int]] = None,
|
| 276 |
-
sigmas: Optional[List[float]] = None,
|
| 277 |
-
use_linear_quadratic_schedule: bool = False,
|
| 278 |
-
linear_quadratic_emulating_steps: int = 250,
|
| 279 |
-
**kwargs,
|
| 280 |
-
):
|
| 281 |
-
"""
|
| 282 |
-
Retrieve timesteps from the scheduler.
|
| 283 |
-
|
| 284 |
-
Args:
|
| 285 |
-
scheduler: The noise scheduler to use.
|
| 286 |
-
num_inference_steps: Number of denoising steps.
|
| 287 |
-
device: Device to place timesteps on.
|
| 288 |
-
timesteps: Custom timestep values (mutually exclusive with sigmas).
|
| 289 |
-
sigmas: Custom sigma values (mutually exclusive with timesteps).
|
| 290 |
-
use_linear_quadratic_schedule: If True, use linear-quadratic sigma schedule.
|
| 291 |
-
This overrides the default linear schedule. Requires num_inference_steps
|
| 292 |
-
to be even.
|
| 293 |
-
linear_quadratic_emulating_steps: Controls the linear portion slope.
|
| 294 |
-
Higher values result in gentler slope in the first half. Default: 250.
|
| 295 |
-
**kwargs: Additional arguments passed to scheduler.set_timesteps().
|
| 296 |
-
|
| 297 |
-
Returns:
|
| 298 |
-
Tuple of (timesteps, num_inference_steps).
|
| 299 |
-
|
| 300 |
-
Raises:
|
| 301 |
-
ValueError: If both timesteps and sigmas are provided, or if
|
| 302 |
-
use_linear_quadratic_schedule is True but num_inference_steps is odd.
|
| 303 |
-
"""
|
| 304 |
-
if timesteps is not None and sigmas is not None:
|
| 305 |
-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 306 |
-
|
| 307 |
-
# Handle linear-quadratic schedule: compute sigmas if flag is set
|
| 308 |
-
if use_linear_quadratic_schedule:
|
| 309 |
-
if sigmas is not None:
|
| 310 |
-
raise ValueError(
|
| 311 |
-
"Cannot use both `sigmas` and `use_linear_quadratic_schedule`. "
|
| 312 |
-
"The linear-quadratic schedule computes sigmas automatically."
|
| 313 |
-
)
|
| 314 |
-
if num_inference_steps is None:
|
| 315 |
-
raise ValueError("`num_inference_steps` must be provided when using `use_linear_quadratic_schedule`.")
|
| 316 |
-
sigmas = get_linear_quadratic_sigmas(
|
| 317 |
-
num_inference_steps=num_inference_steps,
|
| 318 |
-
linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
|
| 319 |
-
)
|
| 320 |
-
|
| 321 |
-
if timesteps is not None:
|
| 322 |
-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 323 |
-
if not accepts_timesteps:
|
| 324 |
-
raise ValueError(
|
| 325 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 326 |
-
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 327 |
-
)
|
| 328 |
-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 329 |
-
timesteps = scheduler.timesteps
|
| 330 |
-
num_inference_steps = len(timesteps)
|
| 331 |
-
elif sigmas is not None:
|
| 332 |
-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 333 |
-
if not accept_sigmas:
|
| 334 |
-
raise ValueError(
|
| 335 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 336 |
-
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 337 |
-
)
|
| 338 |
-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 339 |
-
timesteps = scheduler.timesteps
|
| 340 |
-
num_inference_steps = len(timesteps)
|
| 341 |
-
else:
|
| 342 |
-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 343 |
-
timesteps = scheduler.timesteps
|
| 344 |
-
return timesteps, num_inference_steps
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
def basic_clean(text):
|
| 348 |
-
text = ftfy.fix_text(text)
|
| 349 |
-
text = html.unescape(html.unescape(text))
|
| 350 |
-
return text.strip()
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
def whitespace_clean(text):
|
| 354 |
-
text = re.sub(r"\s+", " ", text)
|
| 355 |
-
text = text.strip()
|
| 356 |
-
return text
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
def prompt_clean(text):
|
| 360 |
-
text = whitespace_clean(basic_clean(text))
|
| 361 |
-
return text
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
class MotifVideoPipeline(DiffusionPipeline):
|
| 365 |
-
r"""
|
| 366 |
-
Pipeline for text-to-video generation using MotifVideoTransformer.
|
| 367 |
-
|
| 368 |
-
Args:
|
| 369 |
-
transformer ([`MotifVideoTransformer3DModel`]):
|
| 370 |
-
Conditional Transformer architecture to denoise the encoded video latents.
|
| 371 |
-
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 372 |
-
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 373 |
-
vae ([`AutoencoderKLWan`]):
|
| 374 |
-
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 375 |
-
text_encoder ([`T5Gemma2Encoder`]):
|
| 376 |
-
Primary text encoder for encoding text prompts into embeddings.
|
| 377 |
-
tokenizer ([`PreTrainedTokenizerBase`]):
|
| 378 |
-
Tokenizer corresponding to the primary text encoder.
|
| 379 |
-
guider ([`ClassifierFreeGuidance`] or [`SkipLayerGuidance`] or [`AdaptiveProjectedGuidance`] or [`VideoAdaptiveProjectedGuidance`], *optional*):
|
| 380 |
-
The guidance method to use. If `None`, it defaults to `ClassifierFreeGuidance()`.
|
| 381 |
-
"""
|
| 382 |
-
|
| 383 |
-
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 384 |
-
_optional_components = ["feature_extractor"]
|
| 385 |
-
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 386 |
-
|
| 387 |
-
def __init__(
|
| 388 |
-
self,
|
| 389 |
-
scheduler: Union[
|
| 390 |
-
FlowMatchEulerDiscreteScheduler,
|
| 391 |
-
DPMSolverMultistepScheduler,
|
| 392 |
-
UniPCMultistepScheduler,
|
| 393 |
-
FlowUniPCMultistepScheduler,
|
| 394 |
-
],
|
| 395 |
-
vae: AutoencoderKLWan,
|
| 396 |
-
text_encoder: T5Gemma2Encoder,
|
| 397 |
-
tokenizer: PreTrainedTokenizerBase,
|
| 398 |
-
transformer,
|
| 399 |
-
guider: Optional[
|
| 400 |
-
Union[
|
| 401 |
-
ClassifierFreeGuidance,
|
| 402 |
-
SkipLayerGuidance,
|
| 403 |
-
AdaptiveProjectedGuidance,
|
| 404 |
-
VideoAdaptiveProjectedGuidance,
|
| 405 |
-
]
|
| 406 |
-
] = None,
|
| 407 |
-
feature_extractor: Optional[SiglipImageProcessor] = None,
|
| 408 |
-
):
|
| 409 |
-
super().__init__()
|
| 410 |
-
|
| 411 |
-
self.guider = ClassifierFreeGuidance() if guider is None else guider
|
| 412 |
-
|
| 413 |
-
self.register_modules(
|
| 414 |
-
vae=vae,
|
| 415 |
-
text_encoder=text_encoder,
|
| 416 |
-
tokenizer=tokenizer,
|
| 417 |
-
transformer=transformer,
|
| 418 |
-
scheduler=scheduler,
|
| 419 |
-
feature_extractor=feature_extractor,
|
| 420 |
-
)
|
| 421 |
-
|
| 422 |
-
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
| 423 |
-
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
| 424 |
-
|
| 425 |
-
self.transformer_spatial_patch_size = (
|
| 426 |
-
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2
|
| 427 |
-
)
|
| 428 |
-
self.transformer_temporal_patch_size = (
|
| 429 |
-
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
|
| 430 |
-
)
|
| 431 |
-
|
| 432 |
-
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 433 |
-
self.tokenizer_max_length = (
|
| 434 |
-
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512
|
| 435 |
-
)
|
| 436 |
-
|
| 437 |
-
def _get_default_embeds(
|
| 438 |
-
self,
|
| 439 |
-
text_encoder,
|
| 440 |
-
tokenizer: PreTrainedTokenizerBase,
|
| 441 |
-
prompt: Union[str, List[str]],
|
| 442 |
-
max_sequence_length: int = 512,
|
| 443 |
-
device: Optional[torch.device] = None,
|
| 444 |
-
dtype: Optional[torch.dtype] = None,
|
| 445 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 446 |
-
dtype = dtype or text_encoder.dtype
|
| 447 |
-
|
| 448 |
-
text_inputs = tokenizer(
|
| 449 |
-
prompt,
|
| 450 |
-
padding="max_length",
|
| 451 |
-
max_length=max_sequence_length,
|
| 452 |
-
truncation=True,
|
| 453 |
-
add_special_tokens=True,
|
| 454 |
-
return_attention_mask=True,
|
| 455 |
-
return_tensors="pt",
|
| 456 |
-
)
|
| 457 |
-
text_inputs = BatchEncoding(
|
| 458 |
-
{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()}
|
| 459 |
-
)
|
| 460 |
-
|
| 461 |
-
prompt_embeds = text_encoder(**text_inputs)[0]
|
| 462 |
-
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 463 |
-
|
| 464 |
-
return prompt_embeds, text_inputs.attention_mask
|
| 465 |
-
|
| 466 |
-
def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
| 467 |
-
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
| 468 |
-
denom = attention_mask.sum(dim=1, keepdim=True).clamp(min=1) # avoid div by zero
|
| 469 |
-
return last_hidden.sum(dim=1) / denom
|
| 470 |
-
|
| 471 |
-
def _get_prompt_embeds(
|
| 472 |
-
self,
|
| 473 |
-
text_encoder: T5Gemma2Encoder,
|
| 474 |
-
tokenizer: PreTrainedTokenizerBase,
|
| 475 |
-
prompt: Union[str, List[str]] | None = None,
|
| 476 |
-
num_videos_per_prompt: int = 1,
|
| 477 |
-
max_sequence_length: int = 512,
|
| 478 |
-
device: Optional[torch.device] = None,
|
| 479 |
-
dtype: Optional[torch.dtype] = None,
|
| 480 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 481 |
-
device = device or self._execution_device
|
| 482 |
-
|
| 483 |
-
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 484 |
-
|
| 485 |
-
prompt_embeds_kwargs = {
|
| 486 |
-
"text_encoder": text_encoder,
|
| 487 |
-
"tokenizer": tokenizer,
|
| 488 |
-
"prompt": prompt,
|
| 489 |
-
"max_sequence_length": max_sequence_length,
|
| 490 |
-
"device": device,
|
| 491 |
-
"dtype": dtype,
|
| 492 |
-
}
|
| 493 |
-
# When enable_model_cpu_offload() is active, the accelerate forward hook is on text_encoder (parent). Moving the encoder to the execution device explicitly ensures inputs and
|
| 494 |
-
# weights are on the same device. The parent's offload hook will move text_encoder back to CPU after
|
| 495 |
-
# the next component claims the GPU.
|
| 496 |
-
if next(text_encoder.parameters()).device != torch.device(device):
|
| 497 |
-
text_encoder.to(device)
|
| 498 |
-
prompt_embeds, prompt_attention_mask = self._get_default_embeds(**prompt_embeds_kwargs)
|
| 499 |
-
|
| 500 |
-
pooled_prompt_embeds = self._average_pool(prompt_embeds, prompt_attention_mask)
|
| 501 |
-
|
| 502 |
-
return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
|
| 503 |
-
|
| 504 |
-
# Keep encode_prompt structure, uses _get_prompt_embeds internally
|
| 505 |
-
def encode_prompt(
|
| 506 |
-
self,
|
| 507 |
-
prompt: Union[str, List[str]],
|
| 508 |
-
num_videos_per_prompt: int = 1,
|
| 509 |
-
prompt_embeds: Optional[torch.Tensor] = None,
|
| 510 |
-
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 511 |
-
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 512 |
-
max_sequence_length: int = 512,
|
| 513 |
-
device: Optional[torch.device] = None,
|
| 514 |
-
dtype: Optional[torch.dtype] = None,
|
| 515 |
-
) -> Tuple[
|
| 516 |
-
torch.Tensor,
|
| 517 |
-
torch.Tensor,
|
| 518 |
-
torch.Tensor,
|
| 519 |
-
]:
|
| 520 |
-
device = device or self._execution_device
|
| 521 |
-
|
| 522 |
-
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 523 |
-
if prompt is not None:
|
| 524 |
-
batch_size = len(prompt)
|
| 525 |
-
else:
|
| 526 |
-
batch_size = prompt_embeds.shape[0]
|
| 527 |
-
|
| 528 |
-
prompt_embeds_kwargs = {
|
| 529 |
-
"device": device,
|
| 530 |
-
"dtype": dtype,
|
| 531 |
-
}
|
| 532 |
-
|
| 533 |
-
if prompt_embeds is None:
|
| 534 |
-
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self._get_prompt_embeds(
|
| 535 |
-
text_encoder=self.text_encoder,
|
| 536 |
-
tokenizer=self.tokenizer,
|
| 537 |
-
prompt=prompt,
|
| 538 |
-
max_sequence_length=max_sequence_length,
|
| 539 |
-
**prompt_embeds_kwargs,
|
| 540 |
-
)
|
| 541 |
-
|
| 542 |
-
# Compute actual (non-padding) token count for batch=1 Flash Attention trimming in __call__
|
| 543 |
-
actual_seq_len = None
|
| 544 |
-
if batch_size == 1 and prompt_attention_mask is not None:
|
| 545 |
-
actual_seq_len = int(prompt_attention_mask.sum(dim=-1).max().item())
|
| 546 |
-
|
| 547 |
-
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 548 |
-
seq_len = prompt_embeds.shape[1]
|
| 549 |
-
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 550 |
-
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 551 |
-
|
| 552 |
-
if pooled_prompt_embeds is not None:
|
| 553 |
-
pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
|
| 554 |
-
|
| 555 |
-
if prompt_attention_mask is not None:
|
| 556 |
-
prompt_attention_mask = prompt_attention_mask.bool()
|
| 557 |
-
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
| 558 |
-
prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
|
| 559 |
-
|
| 560 |
-
return (
|
| 561 |
-
prompt_embeds,
|
| 562 |
-
pooled_prompt_embeds,
|
| 563 |
-
prompt_attention_mask,
|
| 564 |
-
actual_seq_len,
|
| 565 |
-
)
|
| 566 |
-
|
| 567 |
-
@property
|
| 568 |
-
def vision_encoder(self):
|
| 569 |
-
"""Get the vision encoder from T5Gemma2.
|
| 570 |
-
|
| 571 |
-
T5Gemma2 has vision_tower.vision_model structure.
|
| 572 |
-
Will raise AttributeError if not available.
|
| 573 |
-
"""
|
| 574 |
-
return self.text_encoder.vision_tower.vision_model
|
| 575 |
-
|
| 576 |
-
def encode_image(
|
| 577 |
-
self,
|
| 578 |
-
image: Image.Image,
|
| 579 |
-
batch_size: int = 1,
|
| 580 |
-
device: Optional[torch.device] = None,
|
| 581 |
-
dtype: Optional[torch.dtype] = None,
|
| 582 |
-
) -> torch.Tensor:
|
| 583 |
-
"""Encode image to embeddings using SigLIP vision encoder."""
|
| 584 |
-
device = device or self._execution_device
|
| 585 |
-
dtype = dtype or self.transformer.dtype
|
| 586 |
-
|
| 587 |
-
image_embeds = self._get_image_embeds(
|
| 588 |
-
image_encoder=self.vision_encoder,
|
| 589 |
-
feature_extractor=self.feature_extractor,
|
| 590 |
-
image=image,
|
| 591 |
-
device=device,
|
| 592 |
-
)
|
| 593 |
-
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
| 594 |
-
return image_embeds.to(device=device, dtype=dtype)
|
| 595 |
-
|
| 596 |
-
@staticmethod
|
| 597 |
-
def _get_image_embeds(
|
| 598 |
-
image_encoder,
|
| 599 |
-
feature_extractor: SiglipImageProcessor,
|
| 600 |
-
image,
|
| 601 |
-
device: torch.device,
|
| 602 |
-
) -> torch.Tensor:
|
| 603 |
-
"""Helper to encode single image with SigLIP.
|
| 604 |
-
|
| 605 |
-
Args:
|
| 606 |
-
image_encoder: The SigLIP vision encoder model.
|
| 607 |
-
feature_extractor: SiglipImageProcessor for preprocessing.
|
| 608 |
-
image: Can be either:
|
| 609 |
-
- PIL.Image.Image: Will be preprocessed by feature_extractor
|
| 610 |
-
- torch.Tensor: Assumed to be in [0, 1] range, will be normalized and passed to encoder
|
| 611 |
-
device: Device to place tensors on.
|
| 612 |
-
|
| 613 |
-
Returns:
|
| 614 |
-
Image embeddings from the vision encoder.
|
| 615 |
-
"""
|
| 616 |
-
image_encoder_dtype = next(image_encoder.parameters()).dtype
|
| 617 |
-
|
| 618 |
-
if isinstance(image, torch.Tensor):
|
| 619 |
-
image = feature_extractor.preprocess(
|
| 620 |
-
images=image.float(),
|
| 621 |
-
do_resize=True,
|
| 622 |
-
do_rescale=False,
|
| 623 |
-
do_normalize=True,
|
| 624 |
-
do_convert_rgb=True,
|
| 625 |
-
return_tensors="pt",
|
| 626 |
-
)
|
| 627 |
-
else:
|
| 628 |
-
image = feature_extractor.preprocess(
|
| 629 |
-
images=image,
|
| 630 |
-
do_resize=True,
|
| 631 |
-
do_rescale=False,
|
| 632 |
-
do_normalize=True,
|
| 633 |
-
do_convert_rgb=True,
|
| 634 |
-
return_tensors="pt",
|
| 635 |
-
)
|
| 636 |
-
|
| 637 |
-
image = image.to(device, dtype=image_encoder_dtype)
|
| 638 |
-
return image_encoder(**image).last_hidden_state
|
| 639 |
-
|
| 640 |
-
@torch.compiler.disable
|
| 641 |
-
def _prepare_first_frame_conditioning(
|
| 642 |
-
self,
|
| 643 |
-
video: torch.Tensor,
|
| 644 |
-
latents: torch.Tensor,
|
| 645 |
-
use_conditioning: bool,
|
| 646 |
-
generator: Optional[torch.Generator] = None,
|
| 647 |
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
| 648 |
-
"""Prepare first frame conditioning tensors.
|
| 649 |
-
|
| 650 |
-
This method implements batch-level conditioning where entire
|
| 651 |
-
batches are either I2V (all samples conditioned) or T2V (no conditioning). This
|
| 652 |
-
prevents mode confusion within batches.
|
| 653 |
-
|
| 654 |
-
For I2V mode:
|
| 655 |
-
1. Extract and VAE-encode first frame from video
|
| 656 |
-
2. Create latent_condition by repeating first frame across time (frame 0 only)
|
| 657 |
-
3. Create latent_mask with 1.0 at frame 0
|
| 658 |
-
4. Get image_embeds from vision encoder
|
| 659 |
-
|
| 660 |
-
For T2V mode:
|
| 661 |
-
1. Pad with zeros for latent_condition and latent_mask
|
| 662 |
-
|
| 663 |
-
Args:
|
| 664 |
-
video: Input video tensor [batch_size, frames, channels, height, width] in [-1, 1]
|
| 665 |
-
latents: Latents [batch_size, lantent_channels, latent_num_frames, latent_height, latent_width]
|
| 666 |
-
use_conditioning: Whether to use first-frame conditioning (True for I2V, False for T2V)
|
| 667 |
-
generator: Optional random number generator for reproducibility
|
| 668 |
-
|
| 669 |
-
Returns:
|
| 670 |
-
Tuple of (latent_condition, latent_mask, image_embeds).
|
| 671 |
-
- latent_condition: [B, C, F, H, W] conditioning signal (zeros for T2V)
|
| 672 |
-
- latent_mask: [B, 1, F, H, W] binary mask (zeros for T2V)
|
| 673 |
-
- image_embeds: [B, N, D] image embeddings from vision encoder or None for T2V
|
| 674 |
-
"""
|
| 675 |
-
batch_size, lantent_channels, latent_num_frames, latent_height, latent_width = latents.shape
|
| 676 |
-
device = latents.device
|
| 677 |
-
dtype = latents.dtype
|
| 678 |
-
|
| 679 |
-
# Determine if we should use conditioning
|
| 680 |
-
use_conditioning = use_conditioning and (latent_num_frames > 1)
|
| 681 |
-
|
| 682 |
-
# Initialize conditioning tensors
|
| 683 |
-
latent_condition = torch.zeros(
|
| 684 |
-
batch_size,
|
| 685 |
-
lantent_channels,
|
| 686 |
-
latent_num_frames,
|
| 687 |
-
latent_height,
|
| 688 |
-
latent_width,
|
| 689 |
-
device=device,
|
| 690 |
-
dtype=dtype,
|
| 691 |
-
)
|
| 692 |
-
latent_mask = torch.zeros(
|
| 693 |
-
batch_size,
|
| 694 |
-
1,
|
| 695 |
-
latent_num_frames,
|
| 696 |
-
latent_height,
|
| 697 |
-
latent_width,
|
| 698 |
-
device=device,
|
| 699 |
-
dtype=dtype,
|
| 700 |
-
)
|
| 701 |
-
image_embeds = None
|
| 702 |
-
|
| 703 |
-
if use_conditioning:
|
| 704 |
-
with torch.no_grad():
|
| 705 |
-
# Encode first frame for latent_condition
|
| 706 |
-
first_frame_latents = self.vae.encode(
|
| 707 |
-
rearrange(video[:, 0:1], "b f c h w -> b c f h w")
|
| 708 |
-
).latent_dist.sample(generator=generator)
|
| 709 |
-
first_frame_latents = self._normalize_latents(
|
| 710 |
-
latents=first_frame_latents,
|
| 711 |
-
latents_mean=self.vae.config.latents_mean,
|
| 712 |
-
latents_std=self.vae.config.latents_std,
|
| 713 |
-
)
|
| 714 |
-
|
| 715 |
-
# Create latent_condition by repeating first frame across time
|
| 716 |
-
latent_condition = first_frame_latents.repeat(1, 1, latent_num_frames, 1, 1)
|
| 717 |
-
latent_condition[:, :, 1:, :, :] = 0
|
| 718 |
-
|
| 719 |
-
# latent_mask: 1.0 at frame 0, 0.0 elsewhere
|
| 720 |
-
latent_mask[:, :, 0] = 1.0
|
| 721 |
-
|
| 722 |
-
# image_embeds from vision encoder
|
| 723 |
-
first_frame_vision = video[:, 0] # [B, C, H, W]
|
| 724 |
-
first_frame_vision = ((first_frame_vision + 1) / 2).clamp(0, 1)
|
| 725 |
-
|
| 726 |
-
with torch.no_grad():
|
| 727 |
-
image_embeds = self._get_image_embeds(
|
| 728 |
-
image_encoder=self.vision_encoder,
|
| 729 |
-
feature_extractor=self.feature_extractor,
|
| 730 |
-
image=first_frame_vision,
|
| 731 |
-
device=device,
|
| 732 |
-
)
|
| 733 |
-
|
| 734 |
-
return latent_condition, latent_mask, image_embeds
|
| 735 |
-
|
| 736 |
-
def check_inputs(
|
| 737 |
-
self,
|
| 738 |
-
prompt,
|
| 739 |
-
negative_prompt,
|
| 740 |
-
height,
|
| 741 |
-
width,
|
| 742 |
-
batch_size,
|
| 743 |
-
callback_on_step_end_tensor_inputs=None,
|
| 744 |
-
prompt_embeds=None,
|
| 745 |
-
negative_prompt_embeds=None,
|
| 746 |
-
prompt_attention_mask=None,
|
| 747 |
-
negative_prompt_attention_mask=None,
|
| 748 |
-
):
|
| 749 |
-
# Resolution must be divisible by VAE scale factor * transformer patch size
|
| 750 |
-
# (e.g. 8 * 2 = 16 for default config) to avoid latent/patch dimension mismatch.
|
| 751 |
-
spatial_divisor = self.vae_scale_factor_spatial * self.transformer_spatial_patch_size
|
| 752 |
-
if height % spatial_divisor != 0 or width % spatial_divisor != 0:
|
| 753 |
-
raise ValueError(
|
| 754 |
-
f"`height` and `width` have to be divisible by {spatial_divisor} "
|
| 755 |
-
f"(vae_scale={self.vae_scale_factor_spatial} * patch_size={self.transformer_spatial_patch_size}) "
|
| 756 |
-
f"but are {height} and {width}."
|
| 757 |
-
)
|
| 758 |
-
|
| 759 |
-
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 760 |
-
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 761 |
-
):
|
| 762 |
-
raise ValueError(
|
| 763 |
-
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 764 |
-
)
|
| 765 |
-
|
| 766 |
-
if prompt is not None and prompt_embeds is not None:
|
| 767 |
-
raise ValueError(
|
| 768 |
-
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 769 |
-
" only forward one of the two."
|
| 770 |
-
)
|
| 771 |
-
elif prompt is None and prompt_embeds is None:
|
| 772 |
-
raise ValueError(
|
| 773 |
-
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 774 |
-
)
|
| 775 |
-
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 776 |
-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 777 |
-
|
| 778 |
-
# Validate negative_prompt: must be None, str, or list with matching batch_size
|
| 779 |
-
if negative_prompt is not None:
|
| 780 |
-
if not isinstance(negative_prompt, (str, list)):
|
| 781 |
-
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
| 782 |
-
if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size:
|
| 783 |
-
raise ValueError(
|
| 784 |
-
f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})."
|
| 785 |
-
)
|
| 786 |
-
|
| 787 |
-
if prompt_embeds is not None and prompt_attention_mask is None:
|
| 788 |
-
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
| 789 |
-
|
| 790 |
-
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
| 791 |
-
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
| 792 |
-
|
| 793 |
-
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 794 |
-
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 795 |
-
raise ValueError(
|
| 796 |
-
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 797 |
-
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 798 |
-
f" {negative_prompt_embeds.shape}."
|
| 799 |
-
)
|
| 800 |
-
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
| 801 |
-
raise ValueError(
|
| 802 |
-
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
| 803 |
-
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
| 804 |
-
f" {negative_prompt_attention_mask.shape}."
|
| 805 |
-
)
|
| 806 |
-
|
| 807 |
-
def _prepare_negative_prompt(
|
| 808 |
-
self,
|
| 809 |
-
negative_prompt: Optional[Union[str, List[str]]],
|
| 810 |
-
batch_size: int,
|
| 811 |
-
) -> List[str]:
|
| 812 |
-
"""
|
| 813 |
-
Prepare negative_prompt to match batch_size.
|
| 814 |
-
|
| 815 |
-
Args:
|
| 816 |
-
negative_prompt: None, a single string, or a list of strings matching batch_size.
|
| 817 |
-
batch_size: The number of prompts in the batch.
|
| 818 |
-
|
| 819 |
-
Returns:
|
| 820 |
-
A list of negative prompts with length equal to batch_size.
|
| 821 |
-
"""
|
| 822 |
-
if negative_prompt is None:
|
| 823 |
-
return [""] * batch_size
|
| 824 |
-
if isinstance(negative_prompt, str):
|
| 825 |
-
return [negative_prompt] * batch_size
|
| 826 |
-
return negative_prompt
|
| 827 |
-
|
| 828 |
-
@staticmethod
|
| 829 |
-
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
| 830 |
-
batch_size, num_channels, num_frames, height, width = latents.shape
|
| 831 |
-
post_patch_num_frames = num_frames // patch_size_t
|
| 832 |
-
post_patch_height = height // patch_size
|
| 833 |
-
post_patch_width = width // patch_size
|
| 834 |
-
latents = latents.reshape(
|
| 835 |
-
batch_size,
|
| 836 |
-
-1,
|
| 837 |
-
post_patch_num_frames,
|
| 838 |
-
patch_size_t,
|
| 839 |
-
post_patch_height,
|
| 840 |
-
patch_size,
|
| 841 |
-
post_patch_width,
|
| 842 |
-
patch_size,
|
| 843 |
-
)
|
| 844 |
-
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
| 845 |
-
return latents
|
| 846 |
-
|
| 847 |
-
@staticmethod
|
| 848 |
-
def _unpack_latents(
|
| 849 |
-
latents: torch.Tensor,
|
| 850 |
-
num_frames: int,
|
| 851 |
-
height: int,
|
| 852 |
-
width: int,
|
| 853 |
-
patch_size: int = 1,
|
| 854 |
-
patch_size_t: int = 1,
|
| 855 |
-
) -> torch.Tensor:
|
| 856 |
-
batch_size = latents.size(0)
|
| 857 |
-
latents = latents.reshape(
|
| 858 |
-
batch_size,
|
| 859 |
-
num_frames,
|
| 860 |
-
height,
|
| 861 |
-
width,
|
| 862 |
-
-1,
|
| 863 |
-
patch_size_t,
|
| 864 |
-
patch_size,
|
| 865 |
-
patch_size,
|
| 866 |
-
)
|
| 867 |
-
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
| 868 |
-
return latents
|
| 869 |
-
|
| 870 |
-
@staticmethod
|
| 871 |
-
def _normalize_latents(
|
| 872 |
-
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
|
| 873 |
-
) -> torch.Tensor:
|
| 874 |
-
# Normalize latents across the channel dimension [B, C, F, H, W]
|
| 875 |
-
latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
| 876 |
-
latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
| 877 |
-
latents = (latents - latents_mean) / latents_std
|
| 878 |
-
return latents
|
| 879 |
-
|
| 880 |
-
@staticmethod
|
| 881 |
-
def _denormalize_latents(
|
| 882 |
-
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
|
| 883 |
-
) -> torch.Tensor:
|
| 884 |
-
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
| 885 |
-
latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
| 886 |
-
latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
| 887 |
-
latents = latents * latents_std + latents_mean
|
| 888 |
-
return latents
|
| 889 |
-
|
| 890 |
-
def prepare_latents(
|
| 891 |
-
self,
|
| 892 |
-
batch_size: int = 1,
|
| 893 |
-
num_channels_latents: int = 16,
|
| 894 |
-
height: int = 352,
|
| 895 |
-
width: int = 640,
|
| 896 |
-
num_frames: int = 65,
|
| 897 |
-
dtype: Optional[torch.dtype] = None,
|
| 898 |
-
device: Optional[torch.device] = None,
|
| 899 |
-
generator: Optional[torch.Generator] = None,
|
| 900 |
-
latents: Optional[torch.Tensor] = None,
|
| 901 |
-
) -> torch.Tensor:
|
| 902 |
-
if latents is not None:
|
| 903 |
-
return latents.to(device=device, dtype=dtype)
|
| 904 |
-
|
| 905 |
-
shape = (
|
| 906 |
-
batch_size,
|
| 907 |
-
num_channels_latents,
|
| 908 |
-
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
| 909 |
-
height // self.vae_scale_factor_spatial,
|
| 910 |
-
width // self.vae_scale_factor_spatial,
|
| 911 |
-
)
|
| 912 |
-
|
| 913 |
-
if isinstance(generator, list) and len(generator) != batch_size:
|
| 914 |
-
raise ValueError(
|
| 915 |
-
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 916 |
-
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 917 |
-
)
|
| 918 |
-
|
| 919 |
-
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 920 |
-
return latents
|
| 921 |
-
|
| 922 |
-
@property
|
| 923 |
-
def num_timesteps(self):
|
| 924 |
-
return self._num_timesteps
|
| 925 |
-
|
| 926 |
-
@property
|
| 927 |
-
def current_timestep(self):
|
| 928 |
-
return self._current_timestep
|
| 929 |
-
|
| 930 |
-
@property
|
| 931 |
-
def attention_kwargs(self):
|
| 932 |
-
return self._attention_kwargs
|
| 933 |
-
|
| 934 |
-
@property
|
| 935 |
-
def interrupt(self):
|
| 936 |
-
return self._interrupt
|
| 937 |
-
|
| 938 |
-
@torch.no_grad()
|
| 939 |
-
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 940 |
-
def __call__(
|
| 941 |
-
self,
|
| 942 |
-
prompt: Union[str, List[str]] | None = None,
|
| 943 |
-
image=None,
|
| 944 |
-
negative_prompt: Optional[
|
| 945 |
-
Union[str, List[str]]
|
| 946 |
-
] = "text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
|
| 947 |
-
height: int = 736,
|
| 948 |
-
width: int = 1280,
|
| 949 |
-
num_frames: int = 121,
|
| 950 |
-
frame_rate: int = 24,
|
| 951 |
-
num_inference_steps: int = 50,
|
| 952 |
-
timesteps: List[int] | None = None,
|
| 953 |
-
use_linear_quadratic_schedule: bool = False,
|
| 954 |
-
linear_quadratic_emulating_steps: int = 250,
|
| 955 |
-
num_videos_per_prompt: Optional[int] = 1,
|
| 956 |
-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 957 |
-
latents: Optional[torch.Tensor] = None,
|
| 958 |
-
prompt_embeds: Optional[torch.Tensor] = None,
|
| 959 |
-
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 960 |
-
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 961 |
-
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 962 |
-
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 963 |
-
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 964 |
-
output_type: Optional[str] = "pil",
|
| 965 |
-
return_dict: bool = True,
|
| 966 |
-
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 967 |
-
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 968 |
-
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 969 |
-
max_sequence_length: int = 512,
|
| 970 |
-
use_attention_mask: bool = True,
|
| 971 |
-
vae_batch_size: int | None = None,
|
| 972 |
-
):
|
| 973 |
-
r"""
|
| 974 |
-
Function invoked when calling the pipeline for generation.
|
| 975 |
-
|
| 976 |
-
Args:
|
| 977 |
-
prompt (`str` or `List[str]`, *optional*):
|
| 978 |
-
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 979 |
-
instead.
|
| 980 |
-
negative_prompt (`str` or `List[str]`, *optional*):
|
| 981 |
-
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 982 |
-
`negative_prompt_embeds` instead. Ignored when not using guidance.
|
| 983 |
-
height (`int`, defaults to `352`):
|
| 984 |
-
The height in pixels of the generated image.
|
| 985 |
-
width (`int`, defaults to `640`):
|
| 986 |
-
The width in pixels of the generated image.
|
| 987 |
-
num_frames (`int`, defaults to `65`):
|
| 988 |
-
The number of video frames to generate
|
| 989 |
-
frame_rate (`int`, defaults to `25`):
|
| 990 |
-
Frame rate for the output video.
|
| 991 |
-
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 992 |
-
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
|
| 993 |
-
expense of slower inference.
|
| 994 |
-
timesteps (`List[int]`, *optional*):
|
| 995 |
-
Custom timesteps to use for the denoising process.
|
| 996 |
-
use_linear_quadratic_schedule (`bool`, defaults to `True`):
|
| 997 |
-
Whether to use a linear-quadratic sigma schedule instead of the default linear schedule.
|
| 998 |
-
This schedule combines linear interpolation in the first half (slow denoising at high noise)
|
| 999 |
-
with quadratic interpolation in the second half (faster denoising toward clean image).
|
| 1000 |
-
Requires `num_inference_steps` to be even.
|
| 1001 |
-
linear_quadratic_emulating_steps (`int`, defaults to `250`):
|
| 1002 |
-
Controls the slope of linear interpolation in the first half of the linear-quadratic schedule.
|
| 1003 |
-
Higher values result in a gentler slope. Only used when `use_linear_quadratic_schedule=True`.
|
| 1004 |
-
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 1005 |
-
The number of videos to generate per prompt.
|
| 1006 |
-
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 1007 |
-
PyTorch Generator object(s) for deterministic generation.
|
| 1008 |
-
latents (`torch.Tensor`, *optional*):
|
| 1009 |
-
Pre-generated noisy latents.
|
| 1010 |
-
prompt_embeds (`torch.Tensor`, *optional*):
|
| 1011 |
-
Pre-generated text embeddings.
|
| 1012 |
-
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 1013 |
-
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 1014 |
-
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 1015 |
-
prompt_attention_mask (`torch.Tensor`, *optional*):
|
| 1016 |
-
Pre-generated attention mask for text embeddings.
|
| 1017 |
-
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 1018 |
-
Pre-generated negative text embeddings.
|
| 1019 |
-
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 1020 |
-
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 1021 |
-
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 1022 |
-
input argument.
|
| 1023 |
-
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
| 1024 |
-
Pre-generated attention mask for negative text embeddings.
|
| 1025 |
-
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1026 |
-
The output format ("pil" or "np").
|
| 1027 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1028 |
-
Whether to return a `MotifVideoPipelineOutput`.
|
| 1029 |
-
attention_kwargs (`dict`, *optional*):
|
| 1030 |
-
Arguments passed to the attention processor.
|
| 1031 |
-
callback_on_step_end (`Callable`, *optional*):
|
| 1032 |
-
Callback function called at the end of each step.
|
| 1033 |
-
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 1034 |
-
Tensors to include in the callback.
|
| 1035 |
-
max_sequence_length (`int` defaults to `512`):
|
| 1036 |
-
Maximum sequence length for the tokenizer.
|
| 1037 |
-
|
| 1038 |
-
Examples:
|
| 1039 |
-
|
| 1040 |
-
Returns:
|
| 1041 |
-
[`~pipelines.motif_video.MotifVideoPipelineOutput`] or `tuple`:
|
| 1042 |
-
If `return_dict` is `True`, returns [`~pipelines.motif_video.MotifVideoPipelineOutput`],
|
| 1043 |
-
otherwise returns a tuple where the first element is a list of generated video frames.
|
| 1044 |
-
"""
|
| 1045 |
-
|
| 1046 |
-
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 1047 |
-
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 1048 |
-
|
| 1049 |
-
# 1. Define call parameters (batch_size needed for check_inputs)
|
| 1050 |
-
if prompt is not None and isinstance(prompt, str):
|
| 1051 |
-
batch_size = 1
|
| 1052 |
-
elif prompt is not None and isinstance(prompt, list):
|
| 1053 |
-
batch_size = len(prompt)
|
| 1054 |
-
else:
|
| 1055 |
-
batch_size = prompt_embeds.shape[0]
|
| 1056 |
-
|
| 1057 |
-
# 2. Check inputs. Raise error if not correct
|
| 1058 |
-
self.check_inputs(
|
| 1059 |
-
prompt=prompt,
|
| 1060 |
-
negative_prompt=negative_prompt,
|
| 1061 |
-
height=height,
|
| 1062 |
-
width=width,
|
| 1063 |
-
batch_size=batch_size,
|
| 1064 |
-
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 1065 |
-
prompt_embeds=prompt_embeds,
|
| 1066 |
-
negative_prompt_embeds=negative_prompt_embeds,
|
| 1067 |
-
prompt_attention_mask=prompt_attention_mask,
|
| 1068 |
-
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
| 1069 |
-
)
|
| 1070 |
-
|
| 1071 |
-
self._attention_kwargs = attention_kwargs
|
| 1072 |
-
self._interrupt = False
|
| 1073 |
-
self._current_timestep = None
|
| 1074 |
-
|
| 1075 |
-
# Auto-upgrade AdaptiveProjectedGuidance to VideoAdaptiveProjectedGuidance
|
| 1076 |
-
# for video generation. Video-aware APG normalizes per-frame [C,H,W] instead
|
| 1077 |
-
# of collapsing the temporal axis, preserving motion quality.
|
| 1078 |
-
if type(self.guider) is AdaptiveProjectedGuidance:
|
| 1079 |
-
self.guider = VideoAdaptiveProjectedGuidance(
|
| 1080 |
-
guidance_scale=self.guider.guidance_scale,
|
| 1081 |
-
adaptive_projected_guidance_rescale=self.guider.adaptive_projected_guidance_rescale,
|
| 1082 |
-
adaptive_projected_guidance_momentum=self.guider.adaptive_projected_guidance_momentum,
|
| 1083 |
-
eta=self.guider.eta,
|
| 1084 |
-
use_original_formulation=self.guider.use_original_formulation,
|
| 1085 |
-
)
|
| 1086 |
-
|
| 1087 |
-
device = self._execution_device
|
| 1088 |
-
|
| 1089 |
-
# 3. Prepare text embeddings
|
| 1090 |
-
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask, pos_actual_len = self.encode_prompt(
|
| 1091 |
-
prompt=prompt,
|
| 1092 |
-
num_videos_per_prompt=num_videos_per_prompt,
|
| 1093 |
-
prompt_embeds=prompt_embeds,
|
| 1094 |
-
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 1095 |
-
prompt_attention_mask=prompt_attention_mask,
|
| 1096 |
-
max_sequence_length=max_sequence_length,
|
| 1097 |
-
device=device,
|
| 1098 |
-
)
|
| 1099 |
-
|
| 1100 |
-
if not self.guider._enabled and pos_actual_len is not None:
|
| 1101 |
-
prompt_embeds = prompt_embeds[:, :pos_actual_len, :]
|
| 1102 |
-
prompt_attention_mask = None
|
| 1103 |
-
|
| 1104 |
-
if self.guider._enabled:
|
| 1105 |
-
negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
|
| 1106 |
-
(
|
| 1107 |
-
negative_prompt_embeds,
|
| 1108 |
-
negative_pooled_prompt_embeds,
|
| 1109 |
-
negative_prompt_attention_mask,
|
| 1110 |
-
neg_actual_len,
|
| 1111 |
-
) = self.encode_prompt(
|
| 1112 |
-
prompt=negative_prompt,
|
| 1113 |
-
num_videos_per_prompt=num_videos_per_prompt,
|
| 1114 |
-
prompt_embeds=negative_prompt_embeds,
|
| 1115 |
-
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 1116 |
-
prompt_attention_mask=negative_prompt_attention_mask,
|
| 1117 |
-
max_sequence_length=max_sequence_length,
|
| 1118 |
-
device=device,
|
| 1119 |
-
)
|
| 1120 |
-
|
| 1121 |
-
# Trim each to its own actual length — guider runs pos/neg in separate loop iterations,
|
| 1122 |
-
# so different seq lengths are fine. No padding embeddings attend without mask.
|
| 1123 |
-
if pos_actual_len is not None and neg_actual_len is not None:
|
| 1124 |
-
prompt_embeds = prompt_embeds[:, :pos_actual_len, :]
|
| 1125 |
-
negative_prompt_embeds = negative_prompt_embeds[:, :neg_actual_len, :]
|
| 1126 |
-
prompt_attention_mask = None
|
| 1127 |
-
negative_prompt_attention_mask = None
|
| 1128 |
-
|
| 1129 |
-
num_channels_latents = self.vae.config.z_dim
|
| 1130 |
-
latents = self.prepare_latents(
|
| 1131 |
-
batch_size * num_videos_per_prompt,
|
| 1132 |
-
num_channels_latents,
|
| 1133 |
-
height,
|
| 1134 |
-
width,
|
| 1135 |
-
num_frames,
|
| 1136 |
-
self.transformer.dtype,
|
| 1137 |
-
device,
|
| 1138 |
-
generator,
|
| 1139 |
-
latents,
|
| 1140 |
-
)
|
| 1141 |
-
|
| 1142 |
-
# 4.5 Preprocess image for I2V conditioning
|
| 1143 |
-
if image is not None:
|
| 1144 |
-
from PIL import Image as PILImage
|
| 1145 |
-
|
| 1146 |
-
if isinstance(image, PILImage.Image):
|
| 1147 |
-
image = image.convert("RGB").resize((width, height), PILImage.LANCZOS)
|
| 1148 |
-
image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
|
| 1149 |
-
image = image * 2.0 - 1.0 # [0,1] -> [-1,1]
|
| 1150 |
-
image = image.unsqueeze(0) # [1, C, H, W]
|
| 1151 |
-
# Handle [C, H, W] -> [1, C, H, W]
|
| 1152 |
-
if image.dim() == 3:
|
| 1153 |
-
image = image.unsqueeze(0)
|
| 1154 |
-
# [B, C, H, W] -> [B, 1, C, H, W] for video format
|
| 1155 |
-
if image.dim() == 4:
|
| 1156 |
-
image = image.unsqueeze(1)
|
| 1157 |
-
image = image.to(device=device, dtype=self.vae.dtype)
|
| 1158 |
-
|
| 1159 |
-
# 5. Prepare timesteps (including mu calculation)
|
| 1160 |
-
|
| 1161 |
-
# Recalculate latent dims based on VAE for mu calculation
|
| 1162 |
-
latent_height = height // self.vae_scale_factor_spatial
|
| 1163 |
-
latent_width = width // self.vae_scale_factor_spatial
|
| 1164 |
-
latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 1165 |
-
|
| 1166 |
-
# Calculate sequence length based on *packed* dimensions if transformer uses packing
|
| 1167 |
-
# Packed dims: H/patch, W/patch, F/patch_t
|
| 1168 |
-
packed_latent_height = latent_height // self.transformer_spatial_patch_size
|
| 1169 |
-
packed_latent_width = latent_width // self.transformer_spatial_patch_size
|
| 1170 |
-
packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size
|
| 1171 |
-
video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width
|
| 1172 |
-
|
| 1173 |
-
# Compute sigmas: use linear-quadratic schedule if enabled, otherwise default linear
|
| 1174 |
-
_is_flow_multistep = isinstance(
|
| 1175 |
-
self.scheduler,
|
| 1176 |
-
(
|
| 1177 |
-
DPMSolverMultistepScheduler,
|
| 1178 |
-
UniPCMultistepScheduler,
|
| 1179 |
-
FlowUniPCMultistepScheduler,
|
| 1180 |
-
),
|
| 1181 |
-
)
|
| 1182 |
-
|
| 1183 |
-
# Compute mu once, shared by both branches (required by FlowUniPCMultistepScheduler)
|
| 1184 |
-
mu = calculate_shift(
|
| 1185 |
-
video_sequence_length,
|
| 1186 |
-
self.scheduler.config.get("base_image_seq_len", 256),
|
| 1187 |
-
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 1188 |
-
self.scheduler.config.get("base_shift", 0.5),
|
| 1189 |
-
self.scheduler.config.get("max_shift", 1.15),
|
| 1190 |
-
)
|
| 1191 |
-
|
| 1192 |
-
if _is_flow_multistep:
|
| 1193 |
-
# DPMSolver/UniPC manage their own sigma schedule via use_flow_sigmas + flow_shift.
|
| 1194 |
-
# Pass mu for dynamic shifting support (required by FlowUniPCMultistepScheduler).
|
| 1195 |
-
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1196 |
-
self.scheduler,
|
| 1197 |
-
num_inference_steps,
|
| 1198 |
-
device,
|
| 1199 |
-
timesteps,
|
| 1200 |
-
mu=mu,
|
| 1201 |
-
)
|
| 1202 |
-
else:
|
| 1203 |
-
if use_linear_quadratic_schedule:
|
| 1204 |
-
# Linear-quadratic schedule computes sigmas internally in retrieve_timesteps
|
| 1205 |
-
sigmas = None
|
| 1206 |
-
else:
|
| 1207 |
-
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 1208 |
-
|
| 1209 |
-
timesteps, num_inference_steps = retrieve_timesteps(
|
| 1210 |
-
self.scheduler,
|
| 1211 |
-
num_inference_steps,
|
| 1212 |
-
device,
|
| 1213 |
-
timesteps,
|
| 1214 |
-
sigmas=sigmas,
|
| 1215 |
-
use_linear_quadratic_schedule=use_linear_quadratic_schedule,
|
| 1216 |
-
linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
|
| 1217 |
-
mu=mu,
|
| 1218 |
-
)
|
| 1219 |
-
|
| 1220 |
-
# Get conditioning tensors
|
| 1221 |
-
latent_condition, latent_mask, image_embeds = self._prepare_first_frame_conditioning(
|
| 1222 |
-
image,
|
| 1223 |
-
latents,
|
| 1224 |
-
use_conditioning=image is not None,
|
| 1225 |
-
generator=generator,
|
| 1226 |
-
)
|
| 1227 |
-
|
| 1228 |
-
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1229 |
-
self._num_timesteps = len(timesteps)
|
| 1230 |
-
|
| 1231 |
-
# 6. Denoising loop
|
| 1232 |
-
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1233 |
-
for i, t in enumerate(timesteps):
|
| 1234 |
-
if self.interrupt:
|
| 1235 |
-
continue
|
| 1236 |
-
|
| 1237 |
-
self._current_timestep = t
|
| 1238 |
-
|
| 1239 |
-
# Concatenate current latents with conditioning for this timestep
|
| 1240 |
-
# [latents | latent_condition | latent_mask]
|
| 1241 |
-
hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1)
|
| 1242 |
-
|
| 1243 |
-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 1244 |
-
timestep = t.expand(latents.shape[0])
|
| 1245 |
-
|
| 1246 |
-
# Step 1: Collect model inputs needed for the guidance method
|
| 1247 |
-
# conditional inputs should always be first element in the tuple
|
| 1248 |
-
guider_inputs = {
|
| 1249 |
-
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
|
| 1250 |
-
}
|
| 1251 |
-
if use_attention_mask and prompt_attention_mask is not None:
|
| 1252 |
-
guider_inputs["encoder_attention_mask"] = (
|
| 1253 |
-
prompt_attention_mask,
|
| 1254 |
-
negative_prompt_attention_mask,
|
| 1255 |
-
)
|
| 1256 |
-
if self.transformer.config.pooled_projection_dim is not None:
|
| 1257 |
-
guider_inputs["pooled_projections"] = (
|
| 1258 |
-
pooled_prompt_embeds,
|
| 1259 |
-
negative_pooled_prompt_embeds,
|
| 1260 |
-
)
|
| 1261 |
-
if image_embeds is not None:
|
| 1262 |
-
guider_inputs["image_embeds"] = (image_embeds, image_embeds)
|
| 1263 |
-
|
| 1264 |
-
# Step 2: Update guider's internal state for this denoising step
|
| 1265 |
-
self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
|
| 1266 |
-
# Sigma injection for guiders that support sigma-based gating
|
| 1267 |
-
# (Kynkäänniemi 2024). Must precede `prepare_inputs` because
|
| 1268 |
-
# `num_conditions` → `_is_cfg_enabled()` reads `_current_sigma`.
|
| 1269 |
-
# Duck-typed so diffusers-native guiders are unaffected; guard
|
| 1270 |
-
# on scheduler too since some schedulers don't expose `sigmas`.
|
| 1271 |
-
if hasattr(self.guider, "_current_sigma") and hasattr(self.scheduler, "sigmas"):
|
| 1272 |
-
self.guider._current_sigma = float(self.scheduler.sigmas[i])
|
| 1273 |
-
|
| 1274 |
-
# Step 3: Prepare batched model inputs based on the guidance method
|
| 1275 |
-
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
| 1276 |
-
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
| 1277 |
-
# you will get a guider_state with two batches:
|
| 1278 |
-
# guider_state = [
|
| 1279 |
-
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
| 1280 |
-
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
| 1281 |
-
# ]
|
| 1282 |
-
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
| 1283 |
-
guider_state = self.guider.prepare_inputs(guider_inputs)
|
| 1284 |
-
|
| 1285 |
-
# Step 4: Run the denoiser for each batch
|
| 1286 |
-
# Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
|
| 1287 |
-
# We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
|
| 1288 |
-
for guider_state_batch in guider_state:
|
| 1289 |
-
self.guider.prepare_models(self.transformer)
|
| 1290 |
-
|
| 1291 |
-
# Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
|
| 1292 |
-
cond_kwargs = {
|
| 1293 |
-
input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
|
| 1294 |
-
}
|
| 1295 |
-
|
| 1296 |
-
tread_disabled = getattr(self.guider, "_current_tread_disabled", False)
|
| 1297 |
-
|
| 1298 |
-
# Override TREAD selection ratio per batch if the guider provides one
|
| 1299 |
-
selection_ratio = getattr(self.guider, "_current_selection_ratio", None)
|
| 1300 |
-
tread_mixin = getattr(self.transformer, "_inference_tread_mixin", None)
|
| 1301 |
-
if (
|
| 1302 |
-
selection_ratio is not None
|
| 1303 |
-
and tread_mixin is not None
|
| 1304 |
-
and tread_mixin._tread_route is not None
|
| 1305 |
-
):
|
| 1306 |
-
tread_mixin._tread_route["sel"] = selection_ratio
|
| 1307 |
-
|
| 1308 |
-
# e.g. "pred_cond"/"pred_uncond"
|
| 1309 |
-
context_name = getattr(guider_state_batch, self.guider._identifier_key)
|
| 1310 |
-
with self.transformer.cache_context(context_name):
|
| 1311 |
-
# Run denoiser and store noise prediction in this batch
|
| 1312 |
-
|
| 1313 |
-
noise_pred = self.transformer(
|
| 1314 |
-
hidden_states=hidden_states,
|
| 1315 |
-
timestep=timestep,
|
| 1316 |
-
attention_kwargs=self.attention_kwargs,
|
| 1317 |
-
return_dict=False,
|
| 1318 |
-
tread_disabled=tread_disabled,
|
| 1319 |
-
**cond_kwargs,
|
| 1320 |
-
)[0].clone()
|
| 1321 |
-
|
| 1322 |
-
guider_state_batch.noise_pred = noise_pred
|
| 1323 |
-
# Cleanup model (e.g., remove hooks)
|
| 1324 |
-
self.guider.cleanup_models(self.transformer)
|
| 1325 |
-
|
| 1326 |
-
# Step 5: Combine predictions using the guidance method
|
| 1327 |
-
# The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
|
| 1328 |
-
# Continuing the CFG example, the guider receives:
|
| 1329 |
-
# guider_state = [
|
| 1330 |
-
# {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
|
| 1331 |
-
# {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
|
| 1332 |
-
# ]
|
| 1333 |
-
# And extracts predictions using the __guidance_identifier__:
|
| 1334 |
-
# pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
|
| 1335 |
-
# pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
|
| 1336 |
-
# Then applies CFG formula:
|
| 1337 |
-
# noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
|
| 1338 |
-
# Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
| 1339 |
-
noise_pred = self.guider(guider_state)[0]
|
| 1340 |
-
|
| 1341 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 1342 |
-
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1343 |
-
|
| 1344 |
-
if callback_on_step_end is not None:
|
| 1345 |
-
callback_kwargs = {}
|
| 1346 |
-
for k in callback_on_step_end_tensor_inputs:
|
| 1347 |
-
callback_kwargs[k] = locals()[k]
|
| 1348 |
-
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1349 |
-
|
| 1350 |
-
latents = callback_outputs.pop("latents", latents)
|
| 1351 |
-
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1352 |
-
# Handle negative embeds if needed by callback
|
| 1353 |
-
if "negative_prompt_embeds" in callback_outputs:
|
| 1354 |
-
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds")
|
| 1355 |
-
|
| 1356 |
-
# call the callback, if provided
|
| 1357 |
-
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1358 |
-
progress_bar.update()
|
| 1359 |
-
|
| 1360 |
-
if XLA_AVAILABLE:
|
| 1361 |
-
xm.mark_step()
|
| 1362 |
-
|
| 1363 |
-
self._current_timestep = None
|
| 1364 |
-
|
| 1365 |
-
if output_type == "latent":
|
| 1366 |
-
video = latents
|
| 1367 |
-
else:
|
| 1368 |
-
latents = latents.to(self.vae.dtype)
|
| 1369 |
-
latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std)
|
| 1370 |
-
if vae_batch_size is not None and latents.shape[0] > vae_batch_size:
|
| 1371 |
-
video_chunks = []
|
| 1372 |
-
for i in range(0, latents.shape[0], vae_batch_size):
|
| 1373 |
-
chunk = latents[i : i + vae_batch_size]
|
| 1374 |
-
video_chunks.append(self.vae.decode(chunk, return_dict=False)[0])
|
| 1375 |
-
video = torch.cat(video_chunks, dim=0)
|
| 1376 |
-
del video_chunks
|
| 1377 |
-
else:
|
| 1378 |
-
video = self.vae.decode(latents, return_dict=False)[0]
|
| 1379 |
-
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
| 1380 |
-
|
| 1381 |
-
# Offload all models
|
| 1382 |
-
self.maybe_free_model_hooks()
|
| 1383 |
-
|
| 1384 |
-
if not return_dict:
|
| 1385 |
-
return (video,)
|
| 1386 |
-
|
| 1387 |
-
# Return updated output type
|
| 1388 |
-
return MotifVideoPipelineOutput(frames=video)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transformer/config.json
CHANGED
|
@@ -3,7 +3,6 @@
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
| 4 |
"_library": "diffusers",
|
| 5 |
"attention_head_dim": 128,
|
| 6 |
-
"base_latent_size": null,
|
| 7 |
"image_embed_dim": 1152,
|
| 8 |
"in_channels": 33,
|
| 9 |
"mlp_ratio": 4.0,
|
|
@@ -15,7 +14,6 @@
|
|
| 15 |
"out_channels": 16,
|
| 16 |
"patch_size": 2,
|
| 17 |
"patch_size_t": 1,
|
| 18 |
-
"pooled_projection_dim": null,
|
| 19 |
"qk_norm": "rms_norm",
|
| 20 |
"rope_axes_dim": [
|
| 21 |
16,
|
|
|
|
| 3 |
"_diffusers_version": "0.36.0",
|
| 4 |
"_library": "diffusers",
|
| 5 |
"attention_head_dim": 128,
|
|
|
|
| 6 |
"image_embed_dim": 1152,
|
| 7 |
"in_channels": 33,
|
| 8 |
"mlp_ratio": 4.0,
|
|
|
|
| 14 |
"out_channels": 16,
|
| 15 |
"patch_size": 2,
|
| 16 |
"patch_size_t": 1,
|
|
|
|
| 17 |
"qk_norm": "rms_norm",
|
| 18 |
"rope_axes_dim": [
|
| 19 |
16,
|
transformer/transformer_motif_video.py
DELETED
|
@@ -1,1350 +0,0 @@
|
|
| 1 |
-
# Copyright 2026 Motif Technologies. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import math
|
| 16 |
-
from functools import lru_cache
|
| 17 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 18 |
-
|
| 19 |
-
import numpy as np
|
| 20 |
-
import torch
|
| 21 |
-
import torch.nn as nn
|
| 22 |
-
import torch.nn.functional as F
|
| 23 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
-
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
|
| 25 |
-
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 26 |
-
from diffusers.models.attention import FeedForward
|
| 27 |
-
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
| 28 |
-
from diffusers.models.cache_utils import CacheMixin
|
| 29 |
-
from diffusers.models.embeddings import (
|
| 30 |
-
PixArtAlphaTextProjection,
|
| 31 |
-
TimestepEmbedding,
|
| 32 |
-
Timesteps,
|
| 33 |
-
)
|
| 34 |
-
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 35 |
-
from diffusers.models.modeling_utils import ModelMixin
|
| 36 |
-
from diffusers.models.normalization import (
|
| 37 |
-
AdaLayerNormContinuous,
|
| 38 |
-
AdaLayerNormZero,
|
| 39 |
-
AdaLayerNormZeroSingle,
|
| 40 |
-
)
|
| 41 |
-
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 42 |
-
|
| 43 |
-
# Stub functions for TREAD (Token REduction with Approximated Distillation).
|
| 44 |
-
# These stubs ensure TREAD code paths are never activated during inference
|
| 45 |
-
# without requiring the motif_core package.
|
| 46 |
-
def is_tread_start(block_idx, start, end): return False
|
| 47 |
-
def is_tread_end(block_idx, start, end): return False
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 51 |
-
|
| 52 |
-
NUM_TRAIN_TIMESTEPS = 1000
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def apply_rotary_emb(
|
| 56 |
-
x: torch.Tensor,
|
| 57 |
-
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
|
| 58 |
-
use_real: bool = True,
|
| 59 |
-
use_real_unbind_dim: int = -1,
|
| 60 |
-
) -> torch.Tensor:
|
| 61 |
-
"""
|
| 62 |
-
Apply rotary positional embeddings (RoPE) to input tensors.
|
| 63 |
-
|
| 64 |
-
This implementation supports both standard 2D RoPE tensors [L, Dh] and batched 4D RoPE
|
| 65 |
-
tensors [B, 1, L, Dh] for compatibility with TREAD's token-dropping mechanism where
|
| 66 |
-
different batches may have different token subsets.
|
| 67 |
-
|
| 68 |
-
Args:
|
| 69 |
-
x: Input tensor of shape [B, H, L, Dh].
|
| 70 |
-
freqs_cis: Tuple of (cos, sin) tensors. Supports shapes [L, Dh] or [B, 1, L, Dh].
|
| 71 |
-
use_real: Whether to use real-valued RoPE implementation.
|
| 72 |
-
use_real_unbind_dim: Dimension to unbind when using real-valued RoPE (-1 or -2).
|
| 73 |
-
|
| 74 |
-
Returns:
|
| 75 |
-
Tensor with rotary embeddings applied, same shape as input x.
|
| 76 |
-
"""
|
| 77 |
-
if use_real:
|
| 78 |
-
cos, sin = freqs_cis
|
| 79 |
-
if cos.dim() == 2: # [L, Dh] → [1, 1, L, Dh]
|
| 80 |
-
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 81 |
-
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 82 |
-
if cos.dim() != 4 or sin.dim() != 4:
|
| 83 |
-
raise RuntimeError(f"RoPE must be 2D or 4D, got cos={cos.dim()}D, sin={sin.dim()}D")
|
| 84 |
-
|
| 85 |
-
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 86 |
-
|
| 87 |
-
if cos.size(-2) != x.size(-2) or cos.size(-1) != x.size(-1):
|
| 88 |
-
raise RuntimeError(
|
| 89 |
-
f"RoPE shape mismatch: rope[-2:]=({cos.size(-2)},{cos.size(-1)}) vs x[-2:]=({x.size(-2)},{x.size(-1)})"
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
if use_real_unbind_dim == -1:
|
| 93 |
-
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
| 94 |
-
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 95 |
-
elif use_real_unbind_dim == -2:
|
| 96 |
-
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
|
| 97 |
-
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 98 |
-
else:
|
| 99 |
-
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 100 |
-
|
| 101 |
-
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 102 |
-
return out
|
| 103 |
-
else:
|
| 104 |
-
x_rot = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 105 |
-
freqs = freqs_cis.unsqueeze(2)
|
| 106 |
-
x_out = torch.view_as_real(x_rot * freqs).flatten(3)
|
| 107 |
-
return x_out.type_as(x)
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
class MotifVideoAttnProcessor2_0:
|
| 111 |
-
def __init__(self):
|
| 112 |
-
if not hasattr(F, "scaled_dot_product_attention"):
|
| 113 |
-
raise ImportError(
|
| 114 |
-
"MotifVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
-
def __call__(
|
| 118 |
-
self,
|
| 119 |
-
attn: Attention,
|
| 120 |
-
hidden_states: torch.Tensor,
|
| 121 |
-
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 122 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 123 |
-
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 124 |
-
query_input: Optional[torch.Tensor] = None,
|
| 125 |
-
key_input: Optional[torch.Tensor] = None,
|
| 126 |
-
value_input: Optional[torch.Tensor] = None,
|
| 127 |
-
) -> torch.Tensor:
|
| 128 |
-
# Cross-attention mode: query already projected externally (cross_attn_query_proj + norm),
|
| 129 |
-
# skip to_q and only apply reshape + norm_q + RoPE. K/V use to_k/to_v as normal.
|
| 130 |
-
if query_input is not None:
|
| 131 |
-
query = query_input.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 132 |
-
key = attn.to_k(key_input)
|
| 133 |
-
value = attn.to_v(value_input)
|
| 134 |
-
|
| 135 |
-
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 136 |
-
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 137 |
-
|
| 138 |
-
if attn.norm_q is not None:
|
| 139 |
-
query = attn.norm_q(query)
|
| 140 |
-
if attn.norm_k is not None:
|
| 141 |
-
key = attn.norm_k(key)
|
| 142 |
-
|
| 143 |
-
if image_rotary_emb is not None:
|
| 144 |
-
query = apply_rotary_emb(query, image_rotary_emb)
|
| 145 |
-
|
| 146 |
-
hidden_states = F.scaled_dot_product_attention(
|
| 147 |
-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 148 |
-
)
|
| 149 |
-
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
| 150 |
-
hidden_states = hidden_states.to(query.dtype)
|
| 151 |
-
return hidden_states, None
|
| 152 |
-
|
| 153 |
-
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 154 |
-
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 155 |
-
|
| 156 |
-
# 1. QKV projections
|
| 157 |
-
query = attn.to_q(hidden_states)
|
| 158 |
-
key = attn.to_k(hidden_states)
|
| 159 |
-
value = attn.to_v(hidden_states)
|
| 160 |
-
|
| 161 |
-
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 162 |
-
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 163 |
-
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 164 |
-
|
| 165 |
-
# 2. QK normalization
|
| 166 |
-
if attn.norm_q is not None:
|
| 167 |
-
query = attn.norm_q(query)
|
| 168 |
-
if attn.norm_k is not None:
|
| 169 |
-
key = attn.norm_k(key)
|
| 170 |
-
|
| 171 |
-
# 3. Rotational positional embeddings applied to latent stream
|
| 172 |
-
if image_rotary_emb is not None:
|
| 173 |
-
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
| 174 |
-
query = torch.cat(
|
| 175 |
-
[
|
| 176 |
-
apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 177 |
-
query[:, :, -encoder_hidden_states.shape[1] :],
|
| 178 |
-
],
|
| 179 |
-
dim=2,
|
| 180 |
-
)
|
| 181 |
-
key = torch.cat(
|
| 182 |
-
[
|
| 183 |
-
apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
| 184 |
-
key[:, :, -encoder_hidden_states.shape[1] :],
|
| 185 |
-
],
|
| 186 |
-
dim=2,
|
| 187 |
-
)
|
| 188 |
-
else:
|
| 189 |
-
query = apply_rotary_emb(query, image_rotary_emb)
|
| 190 |
-
key = apply_rotary_emb(key, image_rotary_emb)
|
| 191 |
-
|
| 192 |
-
# 4. Encoder condition QKV projection and normalization
|
| 193 |
-
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
| 194 |
-
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 195 |
-
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 196 |
-
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 197 |
-
|
| 198 |
-
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 199 |
-
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 200 |
-
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
| 201 |
-
|
| 202 |
-
if attn.norm_added_q is not None:
|
| 203 |
-
encoder_query = attn.norm_added_q(encoder_query)
|
| 204 |
-
if attn.norm_added_k is not None:
|
| 205 |
-
encoder_key = attn.norm_added_k(encoder_key)
|
| 206 |
-
|
| 207 |
-
query = torch.cat([query, encoder_query], dim=2)
|
| 208 |
-
key = torch.cat([key, encoder_key], dim=2)
|
| 209 |
-
value = torch.cat([value, encoder_value], dim=2)
|
| 210 |
-
|
| 211 |
-
# 5. Attention
|
| 212 |
-
hidden_states = F.scaled_dot_product_attention(
|
| 213 |
-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 214 |
-
)
|
| 215 |
-
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
| 216 |
-
hidden_states = hidden_states.to(query.dtype)
|
| 217 |
-
|
| 218 |
-
# 6. Output projection
|
| 219 |
-
if encoder_hidden_states is not None:
|
| 220 |
-
hidden_states, encoder_hidden_states = (
|
| 221 |
-
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
| 222 |
-
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
| 223 |
-
)
|
| 224 |
-
|
| 225 |
-
if getattr(attn, "to_out", None) is not None:
|
| 226 |
-
hidden_states = attn.to_out[0](hidden_states)
|
| 227 |
-
hidden_states = attn.to_out[1](hidden_states)
|
| 228 |
-
|
| 229 |
-
if getattr(attn, "to_add_out", None) is not None:
|
| 230 |
-
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 231 |
-
|
| 232 |
-
return hidden_states, encoder_hidden_states
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
class MotifVideoPatchEmbed(nn.Module):
|
| 236 |
-
def __init__(
|
| 237 |
-
self,
|
| 238 |
-
patch_size: Union[int, Tuple[int, int, int]] = 16,
|
| 239 |
-
in_chans: int = 3,
|
| 240 |
-
embed_dim: int = 768,
|
| 241 |
-
) -> None:
|
| 242 |
-
super().__init__()
|
| 243 |
-
|
| 244 |
-
patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
|
| 245 |
-
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 246 |
-
|
| 247 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 248 |
-
hidden_states = self.proj(hidden_states)
|
| 249 |
-
hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
| 250 |
-
return hidden_states
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
class MotifVideoAdaNorm(nn.Module):
|
| 254 |
-
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
| 255 |
-
super().__init__()
|
| 256 |
-
|
| 257 |
-
out_features = out_features or 2 * in_features
|
| 258 |
-
self.linear = nn.Linear(in_features, out_features)
|
| 259 |
-
self.nonlinearity = nn.SiLU()
|
| 260 |
-
|
| 261 |
-
def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 262 |
-
temb = self.linear(self.nonlinearity(temb))
|
| 263 |
-
gate_msa, gate_mlp = temb.chunk(2, dim=1)
|
| 264 |
-
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
| 265 |
-
return gate_msa, gate_mlp
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
class MotifVideoConditionEmbedding(nn.Module):
|
| 269 |
-
def __init__(
|
| 270 |
-
self,
|
| 271 |
-
embedding_dim: int,
|
| 272 |
-
pooled_projection_dim: int | None,
|
| 273 |
-
):
|
| 274 |
-
super().__init__()
|
| 275 |
-
|
| 276 |
-
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
| 277 |
-
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
| 278 |
-
|
| 279 |
-
if isinstance(pooled_projection_dim, int):
|
| 280 |
-
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
| 281 |
-
|
| 282 |
-
def forward(
|
| 283 |
-
self,
|
| 284 |
-
timestep: torch.Tensor,
|
| 285 |
-
pooled_projection: torch.Tensor | None = None,
|
| 286 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 287 |
-
timesteps_proj = self.time_proj(timestep)
|
| 288 |
-
timestep_embedder_dtype = next(self.timestep_embedder.parameters()).dtype
|
| 289 |
-
conditioning = self.timestep_embedder(timesteps_proj.to(timestep_embedder_dtype)) # (N, D)
|
| 290 |
-
if pooled_projection is not None:
|
| 291 |
-
conditioning = conditioning + self.text_embedder(pooled_projection)
|
| 292 |
-
|
| 293 |
-
token_replace_emb = None
|
| 294 |
-
|
| 295 |
-
return conditioning, token_replace_emb
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L485-L486
|
| 299 |
-
def find_correction_factor(num_rotations, dim, base, max_position_embeddings):
|
| 300 |
-
dtype = num_rotations.dtype if isinstance(num_rotations, torch.Tensor) else torch.float32
|
| 301 |
-
max_pos_tensor = torch.as_tensor(max_position_embeddings, dtype=dtype)
|
| 302 |
-
return (dim * torch.log(max_pos_tensor / (num_rotations * 2 * math.pi))) / (
|
| 303 |
-
2 * math.log(base)
|
| 304 |
-
) # Inverse dim formula to find number of rotations
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L489-L495
|
| 308 |
-
def find_correction_range(low_ratio, high_ratio, dim, base, ori_max_pe_len):
|
| 309 |
-
"""
|
| 310 |
-
Find the correction range for NTK-by-parts interpolation.
|
| 311 |
-
"""
|
| 312 |
-
low = torch.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len))
|
| 313 |
-
high = torch.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len))
|
| 314 |
-
low = torch.clamp(low, min=0)
|
| 315 |
-
high = torch.clamp(high, max=dim - 1)
|
| 316 |
-
return low, high # Clamp values just in case
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L498-L504
|
| 320 |
-
def linear_ramp_mask(min_val, max_val, num_dim):
|
| 321 |
-
if isinstance(min_val, torch.Tensor):
|
| 322 |
-
if (min_val == max_val).all():
|
| 323 |
-
max_val = max_val + 0.001
|
| 324 |
-
elif min_val == max_val:
|
| 325 |
-
max_val += 0.001
|
| 326 |
-
|
| 327 |
-
linear_func = (torch.arange(num_dim, dtype=torch.float32) - min_val) / (max_val - min_val)
|
| 328 |
-
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 329 |
-
return ramp_func
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L507-L511
|
| 333 |
-
def find_newbase_ntk(dim, base, scale):
|
| 334 |
-
"""
|
| 335 |
-
Calculate the new base for NTK-aware scaling.
|
| 336 |
-
"""
|
| 337 |
-
# Avoid division by zero when dim == 2 (or invalid smaller values).
|
| 338 |
-
# In these degenerate cases, fall back to the original base (no NTK adjustment).
|
| 339 |
-
if dim <= 2:
|
| 340 |
-
return base
|
| 341 |
-
return base * (scale ** (dim / (dim - 2)))
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
# Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L514-L652
|
| 345 |
-
def get_1d_rotary_pos_embed(
|
| 346 |
-
dim: int,
|
| 347 |
-
pos: Union[np.ndarray, int],
|
| 348 |
-
theta: float = 10000.0,
|
| 349 |
-
use_real=False,
|
| 350 |
-
linear_factor=1.0,
|
| 351 |
-
ntk_factor=1.0,
|
| 352 |
-
repeat_interleave_real=True,
|
| 353 |
-
freqs_dtype=torch.float32,
|
| 354 |
-
yarn=False,
|
| 355 |
-
max_pe_len=None,
|
| 356 |
-
ori_max_pe_len=64,
|
| 357 |
-
dype=False,
|
| 358 |
-
current_timestep=1.0,
|
| 359 |
-
):
|
| 360 |
-
"""
|
| 361 |
-
Precompute the frequency tensor for complex exponentials with RoPE.
|
| 362 |
-
Supports YARN interpolation for vision transformers.
|
| 363 |
-
|
| 364 |
-
Args:
|
| 365 |
-
dim (`int`):
|
| 366 |
-
Dimension of the frequency tensor.
|
| 367 |
-
pos (`np.ndarray` or `int`):
|
| 368 |
-
Position indices for the frequency tensor. [S] or scalar.
|
| 369 |
-
theta (`float`, *optional*, defaults to 10000.0):
|
| 370 |
-
Scaling factor for frequency computation.
|
| 371 |
-
use_real (`bool`, *optional*, defaults to False):
|
| 372 |
-
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
| 373 |
-
linear_factor (`float`, *optional*, defaults to 1.0):
|
| 374 |
-
Scaling factor for linear interpolation.
|
| 375 |
-
ntk_factor (`float`, *optional*, defaults to 1.0):
|
| 376 |
-
Scaling factor for NTK-Aware RoPE.
|
| 377 |
-
repeat_interleave_real (`bool`, *optional*, defaults to True):
|
| 378 |
-
If True and use_real, real and imaginary parts are interleaved with themselves to reach dim.
|
| 379 |
-
Otherwise, they are concatenated.
|
| 380 |
-
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
| 381 |
-
Data type of the frequency tensor.
|
| 382 |
-
yarn (`bool`, *optional*, defaults to False):
|
| 383 |
-
If True, use YARN interpolation combining NTK, linear, and base methods.
|
| 384 |
-
max_pe_len (`int`, *optional*):
|
| 385 |
-
Maximum position encoding length (current patches for vision models).
|
| 386 |
-
ori_max_pe_len (`int`, *optional*, defaults to 64):
|
| 387 |
-
Original maximum position encoding length (base patches for vision models).
|
| 388 |
-
dype (`bool`, *optional*, defaults to False):
|
| 389 |
-
If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
|
| 390 |
-
current_timestep (`float`, *optional*, defaults to 1.0):
|
| 391 |
-
Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.
|
| 392 |
-
|
| 393 |
-
Returns:
|
| 394 |
-
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
| 395 |
-
If use_real=True, returns tuple of (cos, sin) tensors.
|
| 396 |
-
"""
|
| 397 |
-
assert dim % 2 == 0
|
| 398 |
-
|
| 399 |
-
if isinstance(pos, int):
|
| 400 |
-
pos = torch.arange(pos)
|
| 401 |
-
if isinstance(pos, np.ndarray):
|
| 402 |
-
pos = torch.from_numpy(pos)
|
| 403 |
-
|
| 404 |
-
device = pos.device
|
| 405 |
-
|
| 406 |
-
if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
|
| 407 |
-
if not isinstance(max_pe_len, torch.Tensor):
|
| 408 |
-
max_pe_len = torch.tensor(max_pe_len, dtype=freqs_dtype, device=device)
|
| 409 |
-
|
| 410 |
-
scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
|
| 411 |
-
|
| 412 |
-
beta_0 = 1.25
|
| 413 |
-
beta_1 = 0.75
|
| 414 |
-
gamma_0 = 16
|
| 415 |
-
gamma_1 = 2
|
| 416 |
-
|
| 417 |
-
freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
|
| 418 |
-
|
| 419 |
-
freqs_linear = 1.0 / torch.einsum(
|
| 420 |
-
"..., f -> ... f",
|
| 421 |
-
scale,
|
| 422 |
-
(theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)),
|
| 423 |
-
)
|
| 424 |
-
|
| 425 |
-
new_base = find_newbase_ntk(dim, theta, scale)
|
| 426 |
-
if new_base.dim() > 0:
|
| 427 |
-
new_base = new_base.view(-1, 1)
|
| 428 |
-
freqs_ntk = 1.0 / torch.pow(new_base, (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
|
| 429 |
-
if freqs_ntk.dim() > 1:
|
| 430 |
-
freqs_ntk = freqs_ntk.squeeze()
|
| 431 |
-
|
| 432 |
-
if dype:
|
| 433 |
-
beta_0 = torch.pow(beta_0, 2.0 * torch.pow(current_timestep, 2.0))
|
| 434 |
-
beta_1 = torch.pow(beta_1, 2.0 * torch.pow(current_timestep, 2.0))
|
| 435 |
-
|
| 436 |
-
low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
|
| 437 |
-
high = torch.clamp(high, max=dim // 2)
|
| 438 |
-
|
| 439 |
-
freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
|
| 440 |
-
freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
|
| 441 |
-
|
| 442 |
-
if dype:
|
| 443 |
-
gamma_0 = torch.pow(gamma_0, 2.0 * torch.pow(current_timestep, 2.0))
|
| 444 |
-
gamma_1 = torch.pow(gamma_1, 2.0 * torch.pow(current_timestep, 2.0))
|
| 445 |
-
|
| 446 |
-
low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
|
| 447 |
-
high = torch.clamp(high, max=dim // 2)
|
| 448 |
-
|
| 449 |
-
freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
|
| 450 |
-
freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
|
| 451 |
-
|
| 452 |
-
else:
|
| 453 |
-
theta_ntk = theta * ntk_factor
|
| 454 |
-
freqs = 1.0 / (theta_ntk ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)) / linear_factor
|
| 455 |
-
|
| 456 |
-
freqs = torch.outer(pos, freqs)
|
| 457 |
-
|
| 458 |
-
is_npu = freqs.device.type == "npu"
|
| 459 |
-
if is_npu:
|
| 460 |
-
freqs = freqs.float()
|
| 461 |
-
|
| 462 |
-
if use_real and repeat_interleave_real:
|
| 463 |
-
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
|
| 464 |
-
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
|
| 465 |
-
|
| 466 |
-
if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
|
| 467 |
-
mscale = torch.where(scale <= 1.0, 1.0, 0.1 * torch.log(scale) + 1.0).to(scale)
|
| 468 |
-
freqs_cos = freqs_cos * mscale
|
| 469 |
-
freqs_sin = freqs_sin * mscale
|
| 470 |
-
|
| 471 |
-
return freqs_cos, freqs_sin
|
| 472 |
-
elif use_real:
|
| 473 |
-
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
|
| 474 |
-
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
|
| 475 |
-
return freqs_cos, freqs_sin
|
| 476 |
-
else:
|
| 477 |
-
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 478 |
-
return freqs_cis
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
class MotifVideoRotaryPosEmbed(nn.Module):
|
| 482 |
-
def __init__(
|
| 483 |
-
self,
|
| 484 |
-
patch_size: int,
|
| 485 |
-
patch_size_t: int,
|
| 486 |
-
rope_dim: List[int],
|
| 487 |
-
theta: float = 256.0,
|
| 488 |
-
base_latent_size: int | None = None,
|
| 489 |
-
):
|
| 490 |
-
"""
|
| 491 |
-
Rotary Positional Embedding (RoPE) for video latents.
|
| 492 |
-
|
| 493 |
-
Args:
|
| 494 |
-
patch_size (`int`):
|
| 495 |
-
Spatial patch size (e.g., 2).
|
| 496 |
-
patch_size_t (`int`):
|
| 497 |
-
Temporal patch size (e.g., 1).
|
| 498 |
-
rope_dim (`List[int]`):
|
| 499 |
-
Dimensions for RoPE across [Time, Height, Width] axes.
|
| 500 |
-
theta (`float`, *optional*, defaults to 256.0):
|
| 501 |
-
Base frequency for rotary embeddings.
|
| 502 |
-
base_latent_size (`int`, *optional*):
|
| 503 |
-
The maximum spatial dimension (in latent units) seen during training,
|
| 504 |
-
i.e. `training_resolution / vae_scale_factor_spatial`.
|
| 505 |
-
For example, for 1280x1280 training images and a VAE spatial downscale
|
| 506 |
-
(`vae_scale_factor_spatial`) of 8, this would be 160; for a downscale
|
| 507 |
-
of 16, it would be 80.
|
| 508 |
-
"""
|
| 509 |
-
super().__init__()
|
| 510 |
-
|
| 511 |
-
self.patch_size = patch_size
|
| 512 |
-
self.patch_size_t = patch_size_t
|
| 513 |
-
self.rope_dim = rope_dim
|
| 514 |
-
self.theta = theta
|
| 515 |
-
self.base_latent_size = base_latent_size
|
| 516 |
-
|
| 517 |
-
@lru_cache(maxsize=8)
|
| 518 |
-
def _get_base_patch_grid_size(self, base_latent_size: Optional[int], patch_size: int) -> Optional[int]:
|
| 519 |
-
return base_latent_size // patch_size if base_latent_size else None
|
| 520 |
-
|
| 521 |
-
@lru_cache(maxsize=8)
|
| 522 |
-
def _get_dynamic_interpolation_scale(self, h: int, w: int, base_grid_size: int) -> float:
|
| 523 |
-
return math.sqrt(h * w / (base_grid_size**2))
|
| 524 |
-
|
| 525 |
-
def forward(self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 526 |
-
if self.training:
|
| 527 |
-
assert self.base_latent_size is None, (
|
| 528 |
-
"RoPE interpolation/extrapolation logic should only be enabled for inference. "
|
| 529 |
-
f"During training, base_latent_size must be None, but got {self.base_latent_size!r}."
|
| 530 |
-
)
|
| 531 |
-
|
| 532 |
-
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 533 |
-
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
|
| 534 |
-
|
| 535 |
-
axes_grids = []
|
| 536 |
-
for i in range(3):
|
| 537 |
-
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
|
| 538 |
-
# original implementation creates it on CPU and then moves it to device. This results in numerical
|
| 539 |
-
# differences in layerwise debugging outputs, but visually it is the same.
|
| 540 |
-
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
|
| 541 |
-
axes_grids.append(grid)
|
| 542 |
-
grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
|
| 543 |
-
grid = torch.stack(grid, dim=0) # [3, W, H, T]
|
| 544 |
-
|
| 545 |
-
base_patch_grid_size = self._get_base_patch_grid_size(self.base_latent_size, self.patch_size)
|
| 546 |
-
if base_patch_grid_size is not None:
|
| 547 |
-
if base_patch_grid_size <= 0:
|
| 548 |
-
raise ValueError(f"base_patch_grid_size must be a positive number, got {base_patch_grid_size}.")
|
| 549 |
-
dynamic_interpolation_scale = self._get_dynamic_interpolation_scale(
|
| 550 |
-
rope_sizes[1], rope_sizes[2], base_patch_grid_size
|
| 551 |
-
)
|
| 552 |
-
|
| 553 |
-
normalized_timestep = torch.tensor(1.0)
|
| 554 |
-
if not self.training and timestep is not None:
|
| 555 |
-
normalized_timestep = timestep[0] / NUM_TRAIN_TIMESTEPS
|
| 556 |
-
|
| 557 |
-
freqs = []
|
| 558 |
-
for i in range(3):
|
| 559 |
-
common_kwargs = {
|
| 560 |
-
"dim": self.rope_dim[i],
|
| 561 |
-
"pos": grid[i].reshape(-1),
|
| 562 |
-
"theta": self.theta,
|
| 563 |
-
"use_real": True,
|
| 564 |
-
"freqs_dtype": torch.float64,
|
| 565 |
-
}
|
| 566 |
-
|
| 567 |
-
# Apply scaling only to spatial dimensions (Height and Width, i=1 and i=2)
|
| 568 |
-
if i > 0 and base_patch_grid_size is not None and dynamic_interpolation_scale > 1.0:
|
| 569 |
-
# We project the training base to the current size using the uniform scale factor.
|
| 570 |
-
# max_pe_len tells the RoPE logic the "new" maximum length it's dealing with.
|
| 571 |
-
max_pe_len = torch.tensor(
|
| 572 |
-
base_patch_grid_size * dynamic_interpolation_scale,
|
| 573 |
-
dtype=torch.float64,
|
| 574 |
-
device=hidden_states.device,
|
| 575 |
-
)
|
| 576 |
-
|
| 577 |
-
freq = get_1d_rotary_pos_embed(
|
| 578 |
-
**common_kwargs,
|
| 579 |
-
yarn=True, # Enable Yet Another RoPE extensioN (YARN) for extrapolation
|
| 580 |
-
max_pe_len=max_pe_len,
|
| 581 |
-
ori_max_pe_len=base_patch_grid_size, # The original training scale
|
| 582 |
-
dype=True, # Enable Dynamic Position Encoding (time-aware)
|
| 583 |
-
current_timestep=normalized_timestep,
|
| 584 |
-
)
|
| 585 |
-
else:
|
| 586 |
-
# Time dimension OR within training bounds -> Standard RoPE
|
| 587 |
-
freq = get_1d_rotary_pos_embed(**common_kwargs)
|
| 588 |
-
|
| 589 |
-
freqs.append(freq)
|
| 590 |
-
|
| 591 |
-
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 592 |
-
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
| 593 |
-
return freqs_cos, freqs_sin
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
class MotifVideoImageProjection(nn.Module):
|
| 597 |
-
def __init__(self, in_features: int, hidden_size: int):
|
| 598 |
-
super().__init__()
|
| 599 |
-
self.norm_in = nn.LayerNorm(in_features)
|
| 600 |
-
self.linear_1 = nn.Linear(in_features, in_features)
|
| 601 |
-
self.act_fn = nn.GELU()
|
| 602 |
-
self.linear_2 = nn.Linear(in_features, hidden_size)
|
| 603 |
-
self.norm_out = nn.LayerNorm(hidden_size)
|
| 604 |
-
|
| 605 |
-
def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
|
| 606 |
-
hidden_states = self.norm_in(image_embeds)
|
| 607 |
-
hidden_states = self.linear_1(hidden_states)
|
| 608 |
-
hidden_states = self.act_fn(hidden_states)
|
| 609 |
-
hidden_states = self.linear_2(hidden_states)
|
| 610 |
-
hidden_states = self.norm_out(hidden_states)
|
| 611 |
-
return hidden_states
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
class MotifVideoSingleTransformerBlock(nn.Module):
|
| 615 |
-
def __init__(
|
| 616 |
-
self,
|
| 617 |
-
num_attention_heads: int,
|
| 618 |
-
attention_head_dim: int,
|
| 619 |
-
mlp_ratio: float = 4.0,
|
| 620 |
-
qk_norm: str = "rms_norm",
|
| 621 |
-
norm_type: str = "layer_norm",
|
| 622 |
-
enable_text_cross_attention: bool = False,
|
| 623 |
-
) -> None:
|
| 624 |
-
super().__init__()
|
| 625 |
-
|
| 626 |
-
hidden_size = num_attention_heads * attention_head_dim
|
| 627 |
-
mlp_dim = int(hidden_size * mlp_ratio)
|
| 628 |
-
|
| 629 |
-
self.attn = Attention(
|
| 630 |
-
query_dim=hidden_size,
|
| 631 |
-
cross_attention_dim=None,
|
| 632 |
-
dim_head=attention_head_dim,
|
| 633 |
-
heads=num_attention_heads,
|
| 634 |
-
out_dim=hidden_size,
|
| 635 |
-
bias=True,
|
| 636 |
-
processor=MotifVideoAttnProcessor2_0(),
|
| 637 |
-
qk_norm=qk_norm,
|
| 638 |
-
eps=1e-6,
|
| 639 |
-
pre_only=True,
|
| 640 |
-
)
|
| 641 |
-
|
| 642 |
-
self.enable_text_cross_attention = enable_text_cross_attention
|
| 643 |
-
if enable_text_cross_attention:
|
| 644 |
-
self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
|
| 645 |
-
self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
|
| 646 |
-
self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
|
| 647 |
-
nn.init.zeros_(self.cross_attn_out_proj.weight)
|
| 648 |
-
nn.init.zeros_(self.cross_attn_out_proj.bias)
|
| 649 |
-
|
| 650 |
-
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type=norm_type)
|
| 651 |
-
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
| 652 |
-
self.act_mlp = nn.GELU(approximate="tanh")
|
| 653 |
-
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
| 654 |
-
|
| 655 |
-
def forward(
|
| 656 |
-
self,
|
| 657 |
-
hidden_states: torch.Tensor,
|
| 658 |
-
encoder_hidden_states: torch.Tensor,
|
| 659 |
-
temb: torch.Tensor,
|
| 660 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 661 |
-
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 662 |
-
token_replace_emb: torch.Tensor | None = None,
|
| 663 |
-
first_frame_num_tokens: int | None = None,
|
| 664 |
-
image_embed_seq_len: int = 0,
|
| 665 |
-
encoder_attention_mask: torch.Tensor | None = None,
|
| 666 |
-
) -> torch.Tensor:
|
| 667 |
-
text_seq_length = encoder_hidden_states.shape[1]
|
| 668 |
-
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
| 669 |
-
|
| 670 |
-
residual = hidden_states
|
| 671 |
-
|
| 672 |
-
# 1. Input normalization
|
| 673 |
-
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 674 |
-
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 675 |
-
|
| 676 |
-
norm_hidden_states, norm_encoder_hidden_states = (
|
| 677 |
-
norm_hidden_states[:, :-text_seq_length, :],
|
| 678 |
-
norm_hidden_states[:, -text_seq_length:, :],
|
| 679 |
-
)
|
| 680 |
-
|
| 681 |
-
# 2. Attention
|
| 682 |
-
attn_output, context_attn_output = self.attn(
|
| 683 |
-
hidden_states=norm_hidden_states,
|
| 684 |
-
encoder_hidden_states=norm_encoder_hidden_states,
|
| 685 |
-
attention_mask=attention_mask,
|
| 686 |
-
image_rotary_emb=image_rotary_emb,
|
| 687 |
-
)
|
| 688 |
-
|
| 689 |
-
# Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
|
| 690 |
-
if self.enable_text_cross_attention:
|
| 691 |
-
txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
|
| 692 |
-
text_mask = None
|
| 693 |
-
if encoder_attention_mask is not None:
|
| 694 |
-
text_mask = encoder_attention_mask[:, image_embed_seq_len:]
|
| 695 |
-
text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
|
| 696 |
-
cross_q = self.cross_attn_query_proj(attn_output)
|
| 697 |
-
cross_output, _ = self.attn(
|
| 698 |
-
hidden_states=cross_q,
|
| 699 |
-
query_input=cross_q,
|
| 700 |
-
key_input=txt_kv,
|
| 701 |
-
value_input=txt_kv,
|
| 702 |
-
attention_mask=text_mask,
|
| 703 |
-
image_rotary_emb=image_rotary_emb,
|
| 704 |
-
)
|
| 705 |
-
attn_output = attn_output + self.cross_attn_out_proj(cross_output)
|
| 706 |
-
|
| 707 |
-
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
| 708 |
-
|
| 709 |
-
# 3. Modulation and residual connection
|
| 710 |
-
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 711 |
-
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
|
| 712 |
-
hidden_states = hidden_states + residual
|
| 713 |
-
|
| 714 |
-
hidden_states, encoder_hidden_states = (
|
| 715 |
-
hidden_states[:, :-text_seq_length, :],
|
| 716 |
-
hidden_states[:, -text_seq_length:, :],
|
| 717 |
-
)
|
| 718 |
-
return hidden_states, encoder_hidden_states
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
class MotifVideoTransformerBlock(nn.Module):
|
| 722 |
-
def __init__(
|
| 723 |
-
self,
|
| 724 |
-
num_attention_heads: int,
|
| 725 |
-
attention_head_dim: int,
|
| 726 |
-
mlp_ratio: float,
|
| 727 |
-
qk_norm: str = "rms_norm",
|
| 728 |
-
norm_type: str = "layer_norm",
|
| 729 |
-
enable_text_cross_attention: bool = False,
|
| 730 |
-
) -> None:
|
| 731 |
-
super().__init__()
|
| 732 |
-
|
| 733 |
-
hidden_size = num_attention_heads * attention_head_dim
|
| 734 |
-
|
| 735 |
-
self.norm1 = AdaLayerNormZero(hidden_size, norm_type=norm_type)
|
| 736 |
-
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type=norm_type)
|
| 737 |
-
|
| 738 |
-
self.attn = Attention(
|
| 739 |
-
query_dim=hidden_size,
|
| 740 |
-
cross_attention_dim=None,
|
| 741 |
-
added_kv_proj_dim=hidden_size,
|
| 742 |
-
dim_head=attention_head_dim,
|
| 743 |
-
heads=num_attention_heads,
|
| 744 |
-
out_dim=hidden_size,
|
| 745 |
-
context_pre_only=False,
|
| 746 |
-
bias=True,
|
| 747 |
-
processor=MotifVideoAttnProcessor2_0(),
|
| 748 |
-
qk_norm=qk_norm,
|
| 749 |
-
eps=1e-6,
|
| 750 |
-
)
|
| 751 |
-
|
| 752 |
-
self.enable_text_cross_attention = enable_text_cross_attention
|
| 753 |
-
if enable_text_cross_attention:
|
| 754 |
-
self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
|
| 755 |
-
self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
|
| 756 |
-
self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
|
| 757 |
-
nn.init.zeros_(self.cross_attn_out_proj.weight)
|
| 758 |
-
nn.init.zeros_(self.cross_attn_out_proj.bias)
|
| 759 |
-
|
| 760 |
-
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 761 |
-
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 762 |
-
|
| 763 |
-
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 764 |
-
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
| 765 |
-
|
| 766 |
-
def forward(
|
| 767 |
-
self,
|
| 768 |
-
hidden_states: torch.Tensor,
|
| 769 |
-
encoder_hidden_states: torch.Tensor,
|
| 770 |
-
temb: torch.Tensor,
|
| 771 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 772 |
-
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 773 |
-
token_replace_emb: torch.Tensor | None = None,
|
| 774 |
-
first_frame_num_tokens: int | None = None,
|
| 775 |
-
image_embed_seq_len: int = 0,
|
| 776 |
-
encoder_attention_mask: torch.Tensor | None = None,
|
| 777 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 778 |
-
# 1. Input normalization
|
| 779 |
-
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 780 |
-
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 781 |
-
encoder_hidden_states, emb=temb
|
| 782 |
-
)
|
| 783 |
-
|
| 784 |
-
# 2. Joint attention
|
| 785 |
-
attn_output, context_attn_output = self.attn(
|
| 786 |
-
hidden_states=norm_hidden_states,
|
| 787 |
-
encoder_hidden_states=norm_encoder_hidden_states,
|
| 788 |
-
attention_mask=attention_mask,
|
| 789 |
-
image_rotary_emb=image_rotary_emb,
|
| 790 |
-
)
|
| 791 |
-
|
| 792 |
-
# 3. Modulation and residual connection
|
| 793 |
-
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
| 794 |
-
|
| 795 |
-
# Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
|
| 796 |
-
if self.enable_text_cross_attention:
|
| 797 |
-
txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
|
| 798 |
-
text_mask = None
|
| 799 |
-
if encoder_attention_mask is not None:
|
| 800 |
-
text_mask = encoder_attention_mask[:, image_embed_seq_len:]
|
| 801 |
-
text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
|
| 802 |
-
cross_q = self.cross_attn_query_proj(attn_output)
|
| 803 |
-
cross_output, _ = self.attn(
|
| 804 |
-
hidden_states=cross_q,
|
| 805 |
-
query_input=cross_q,
|
| 806 |
-
key_input=txt_kv,
|
| 807 |
-
value_input=txt_kv,
|
| 808 |
-
attention_mask=text_mask,
|
| 809 |
-
image_rotary_emb=image_rotary_emb,
|
| 810 |
-
)
|
| 811 |
-
hidden_states = hidden_states + self.cross_attn_out_proj(cross_output)
|
| 812 |
-
|
| 813 |
-
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
| 814 |
-
|
| 815 |
-
norm_hidden_states = self.norm2(hidden_states)
|
| 816 |
-
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 817 |
-
|
| 818 |
-
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 819 |
-
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 820 |
-
|
| 821 |
-
# 4. Feed-forward
|
| 822 |
-
ff_output = self.ff(norm_hidden_states)
|
| 823 |
-
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 824 |
-
|
| 825 |
-
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
| 826 |
-
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 827 |
-
|
| 828 |
-
return hidden_states, encoder_hidden_states
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
TransformerBlockRegistry.register(
|
| 832 |
-
model_class=MotifVideoTransformerBlock,
|
| 833 |
-
metadata=TransformerBlockMetadata(
|
| 834 |
-
return_hidden_states_index=0,
|
| 835 |
-
return_encoder_hidden_states_index=1,
|
| 836 |
-
),
|
| 837 |
-
)
|
| 838 |
-
TransformerBlockRegistry.register(
|
| 839 |
-
model_class=MotifVideoSingleTransformerBlock,
|
| 840 |
-
metadata=TransformerBlockMetadata(
|
| 841 |
-
return_hidden_states_index=0,
|
| 842 |
-
return_encoder_hidden_states_index=1,
|
| 843 |
-
),
|
| 844 |
-
)
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
class MotifVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
| 848 |
-
r"""
|
| 849 |
-
A Transformer model for video-like data used in [MotifVideo](https://huggingface.co/motif/motifvideo).
|
| 850 |
-
|
| 851 |
-
Args:
|
| 852 |
-
in_channels (`int`, defaults to `16`):
|
| 853 |
-
The number of channels in the input.
|
| 854 |
-
out_channels (`int`, defaults to `16`):
|
| 855 |
-
The number of channels in the output.
|
| 856 |
-
num_attention_heads (`int`, defaults to `24`):
|
| 857 |
-
The number of heads to use for multi-head attention.
|
| 858 |
-
attention_head_dim (`int`, defaults to `128`):
|
| 859 |
-
The number of channels in each head.
|
| 860 |
-
num_layers (`int`, defaults to `20`):
|
| 861 |
-
The number of layers of dual-stream blocks to use.
|
| 862 |
-
num_single_layers (`int`, defaults to `40`):
|
| 863 |
-
The number of layers of single-stream blocks to use.
|
| 864 |
-
|
| 865 |
-
mlp_ratio (`float`, defaults to `4.0`):
|
| 866 |
-
The ratio of the hidden layer size to the input size in the feedforward network.
|
| 867 |
-
patch_size (`int`, defaults to `2`):
|
| 868 |
-
The size of the spatial patches to use in the patch embedding layer.
|
| 869 |
-
patch_size_t (`int`, defaults to `1`):
|
| 870 |
-
The size of the temporal patches to use in the patch embedding layer.
|
| 871 |
-
qk_norm (`str`, defaults to `rms_norm`):
|
| 872 |
-
The normalization to use for the query and key projections in the attention layers.
|
| 873 |
-
text_embed_dim (`int`, defaults to `4096`):
|
| 874 |
-
Input dimension of text embeddings from the text encoder.
|
| 875 |
-
rope_theta (`float`, defaults to `256.0`):
|
| 876 |
-
The value of theta to use in the RoPE layer.
|
| 877 |
-
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
| 878 |
-
The dimensions of the axes to use in the RoPE layer.
|
| 879 |
-
base_latent_size (`int`, *optional*):
|
| 880 |
-
The maximum spatial dimension (in latent units) seen during training.
|
| 881 |
-
For example, if trained on 1280x1280 with a VAE downscale of 16, this is 80.
|
| 882 |
-
"""
|
| 883 |
-
|
| 884 |
-
_supports_gradient_checkpointing = True
|
| 885 |
-
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
| 886 |
-
_no_split_modules = [
|
| 887 |
-
"MotifVideoTransformerBlock",
|
| 888 |
-
"MotifVideoSingleTransformerBlock",
|
| 889 |
-
"MotifVideoPatchEmbed",
|
| 890 |
-
]
|
| 891 |
-
|
| 892 |
-
@register_to_config
|
| 893 |
-
def __init__(
|
| 894 |
-
self,
|
| 895 |
-
in_channels: int = 33,
|
| 896 |
-
out_channels: int = 16,
|
| 897 |
-
num_attention_heads: int = 24,
|
| 898 |
-
attention_head_dim: int = 128,
|
| 899 |
-
num_layers: int = 20,
|
| 900 |
-
num_single_layers: int = 40,
|
| 901 |
-
num_decoder_layers: int = 0,
|
| 902 |
-
mlp_ratio: float = 4.0,
|
| 903 |
-
patch_size: int = 2,
|
| 904 |
-
patch_size_t: int = 1,
|
| 905 |
-
qk_norm: str = "rms_norm",
|
| 906 |
-
norm_type: str = "layer_norm",
|
| 907 |
-
text_embed_dim: int = 4096,
|
| 908 |
-
image_embed_dim: int | None = None,
|
| 909 |
-
pooled_projection_dim: int | None = None,
|
| 910 |
-
rope_theta: float = 256.0,
|
| 911 |
-
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
| 912 |
-
base_latent_size: int | None = None,
|
| 913 |
-
enable_text_cross_attention_dual: bool = False,
|
| 914 |
-
enable_text_cross_attention_single: bool = False,
|
| 915 |
-
) -> None:
|
| 916 |
-
super().__init__()
|
| 917 |
-
|
| 918 |
-
inner_dim = num_attention_heads * attention_head_dim
|
| 919 |
-
out_channels = out_channels or in_channels
|
| 920 |
-
|
| 921 |
-
# 1. Latent and condition embedders
|
| 922 |
-
self.x_embedder = MotifVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
|
| 923 |
-
self.context_embedder = PixArtAlphaTextProjection(in_features=text_embed_dim, hidden_size=inner_dim)
|
| 924 |
-
|
| 925 |
-
# First frame conditioning: Image conditioning embedders
|
| 926 |
-
self.image_embed_dim = image_embed_dim
|
| 927 |
-
if image_embed_dim is not None:
|
| 928 |
-
# Project image embeddings from vision encoder to transformer dim
|
| 929 |
-
self.image_embedder = MotifVideoImageProjection(in_features=image_embed_dim, hidden_size=inner_dim)
|
| 930 |
-
|
| 931 |
-
self.time_text_embed = MotifVideoConditionEmbedding(inner_dim, pooled_projection_dim)
|
| 932 |
-
|
| 933 |
-
# 2. RoPE
|
| 934 |
-
self.rope = MotifVideoRotaryPosEmbed(
|
| 935 |
-
patch_size, patch_size_t, rope_axes_dim, rope_theta, base_latent_size=base_latent_size
|
| 936 |
-
)
|
| 937 |
-
|
| 938 |
-
# Cross-attention config
|
| 939 |
-
self.enable_text_cross_attention_dual = enable_text_cross_attention_dual
|
| 940 |
-
self.enable_text_cross_attention_single = enable_text_cross_attention_single
|
| 941 |
-
|
| 942 |
-
# 3. Dual stream transformer blocks
|
| 943 |
-
self.transformer_blocks = nn.ModuleList(
|
| 944 |
-
[
|
| 945 |
-
MotifVideoTransformerBlock(
|
| 946 |
-
num_attention_heads,
|
| 947 |
-
attention_head_dim,
|
| 948 |
-
mlp_ratio=mlp_ratio,
|
| 949 |
-
qk_norm=qk_norm,
|
| 950 |
-
norm_type=norm_type,
|
| 951 |
-
enable_text_cross_attention=enable_text_cross_attention_dual,
|
| 952 |
-
)
|
| 953 |
-
for _ in range(num_layers)
|
| 954 |
-
]
|
| 955 |
-
)
|
| 956 |
-
|
| 957 |
-
# 4. Single stream transformer blocks
|
| 958 |
-
# Encoder blocks get cross-attention; decoder blocks do not (no text stream in decoder)
|
| 959 |
-
num_encoder_single = num_single_layers - num_decoder_layers
|
| 960 |
-
self.single_transformer_blocks = nn.ModuleList(
|
| 961 |
-
[
|
| 962 |
-
MotifVideoSingleTransformerBlock(
|
| 963 |
-
num_attention_heads,
|
| 964 |
-
attention_head_dim,
|
| 965 |
-
mlp_ratio=mlp_ratio,
|
| 966 |
-
qk_norm=qk_norm,
|
| 967 |
-
norm_type=norm_type,
|
| 968 |
-
enable_text_cross_attention=enable_text_cross_attention_single
|
| 969 |
-
if i < num_encoder_single
|
| 970 |
-
else False,
|
| 971 |
-
)
|
| 972 |
-
for i in range(num_single_layers)
|
| 973 |
-
]
|
| 974 |
-
)
|
| 975 |
-
|
| 976 |
-
# 5. Output projection
|
| 977 |
-
self.norm_out = AdaLayerNormContinuous(
|
| 978 |
-
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type=norm_type
|
| 979 |
-
)
|
| 980 |
-
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
|
| 981 |
-
|
| 982 |
-
# Verify cross-attention config matches actual block state.
|
| 983 |
-
# Catches silent misconfiguration (e.g. checkpoint config with renamed keys).
|
| 984 |
-
for i, block in enumerate(self.transformer_blocks):
|
| 985 |
-
if block.enable_text_cross_attention != enable_text_cross_attention_dual:
|
| 986 |
-
raise ValueError(
|
| 987 |
-
f"transformer_blocks[{i}].enable_text_cross_attention="
|
| 988 |
-
f"{block.enable_text_cross_attention}, expected {enable_text_cross_attention_dual}. "
|
| 989 |
-
f"Check checkpoint config.json key names match __init__ parameters."
|
| 990 |
-
)
|
| 991 |
-
num_encoder_single = num_single_layers - num_decoder_layers
|
| 992 |
-
for i, block in enumerate(self.single_transformer_blocks):
|
| 993 |
-
expected = enable_text_cross_attention_single if i < num_encoder_single else False
|
| 994 |
-
if block.enable_text_cross_attention != expected:
|
| 995 |
-
raise ValueError(
|
| 996 |
-
f"single_transformer_blocks[{i}].enable_text_cross_attention="
|
| 997 |
-
f"{block.enable_text_cross_attention}, expected {expected}. "
|
| 998 |
-
f"Check checkpoint config.json key names match __init__ parameters."
|
| 999 |
-
)
|
| 1000 |
-
|
| 1001 |
-
self.gradient_checkpointing = False
|
| 1002 |
-
self.num_decoder_layers = num_decoder_layers
|
| 1003 |
-
|
| 1004 |
-
@property
|
| 1005 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 1006 |
-
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 1007 |
-
r"""
|
| 1008 |
-
Returns:
|
| 1009 |
-
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 1010 |
-
indexed by its weight name.
|
| 1011 |
-
"""
|
| 1012 |
-
# set recursively
|
| 1013 |
-
processors = {}
|
| 1014 |
-
|
| 1015 |
-
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 1016 |
-
if hasattr(module, "get_processor"):
|
| 1017 |
-
processors[f"{name}.processor"] = module.get_processor()
|
| 1018 |
-
|
| 1019 |
-
for sub_name, child in module.named_children():
|
| 1020 |
-
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 1021 |
-
|
| 1022 |
-
return processors
|
| 1023 |
-
|
| 1024 |
-
for name, module in self.named_children():
|
| 1025 |
-
fn_recursive_add_processors(name, module, processors)
|
| 1026 |
-
|
| 1027 |
-
return processors
|
| 1028 |
-
|
| 1029 |
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 1030 |
-
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 1031 |
-
r"""
|
| 1032 |
-
Sets the attention processor to use to compute attention.
|
| 1033 |
-
|
| 1034 |
-
Parameters:
|
| 1035 |
-
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 1036 |
-
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 1037 |
-
for **all** `Attention` layers.
|
| 1038 |
-
|
| 1039 |
-
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 1040 |
-
processor. This is strongly recommended when setting trainable attention processors.
|
| 1041 |
-
|
| 1042 |
-
"""
|
| 1043 |
-
count = len(self.attn_processors.keys())
|
| 1044 |
-
|
| 1045 |
-
if isinstance(processor, dict) and len(processor) != count:
|
| 1046 |
-
raise ValueError(
|
| 1047 |
-
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 1048 |
-
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 1049 |
-
)
|
| 1050 |
-
|
| 1051 |
-
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 1052 |
-
if hasattr(module, "set_processor"):
|
| 1053 |
-
if not isinstance(processor, dict):
|
| 1054 |
-
module.set_processor(processor)
|
| 1055 |
-
else:
|
| 1056 |
-
module.set_processor(processor.pop(f"{name}.processor"))
|
| 1057 |
-
|
| 1058 |
-
for sub_name, child in module.named_children():
|
| 1059 |
-
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 1060 |
-
|
| 1061 |
-
for name, module in self.named_children():
|
| 1062 |
-
fn_recursive_attn_processor(name, module, processor)
|
| 1063 |
-
|
| 1064 |
-
def _maybe_gradient_checkpoint_block(self, block, *args):
|
| 1065 |
-
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1066 |
-
return self._gradient_checkpointing_func(block, *args)
|
| 1067 |
-
return block(*args)
|
| 1068 |
-
|
| 1069 |
-
def _get_unwrapped_blocks(self, blocks):
|
| 1070 |
-
if hasattr(blocks, "_checkpoint_wrapped_module"):
|
| 1071 |
-
return blocks._checkpoint_wrapped_module
|
| 1072 |
-
elif hasattr(blocks, "module"):
|
| 1073 |
-
return blocks.module
|
| 1074 |
-
return blocks
|
| 1075 |
-
|
| 1076 |
-
def _create_attention_mask(
|
| 1077 |
-
self,
|
| 1078 |
-
hidden_states: torch.Tensor,
|
| 1079 |
-
encoder_attention_mask: torch.Tensor,
|
| 1080 |
-
) -> torch.Tensor:
|
| 1081 |
-
"""
|
| 1082 |
-
Create attention mask of shape [B, 1, 1, N] where N = L + E,
|
| 1083 |
-
based on latent tokens (always valid) and the encoder mask.
|
| 1084 |
-
|
| 1085 |
-
Args:
|
| 1086 |
-
hidden_states: [B, L, D]
|
| 1087 |
-
encoder_attention_mask: [B, E] (required)
|
| 1088 |
-
|
| 1089 |
-
Returns:
|
| 1090 |
-
attention_mask: [B, 1, 1, N]
|
| 1091 |
-
"""
|
| 1092 |
-
attention_mask = F.pad(
|
| 1093 |
-
encoder_attention_mask.to(torch.bool),
|
| 1094 |
-
(hidden_states.shape[1], 0),
|
| 1095 |
-
value=True,
|
| 1096 |
-
)
|
| 1097 |
-
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, L+E]
|
| 1098 |
-
return attention_mask
|
| 1099 |
-
|
| 1100 |
-
def forward(
|
| 1101 |
-
self,
|
| 1102 |
-
hidden_states: torch.Tensor,
|
| 1103 |
-
timestep: torch.LongTensor,
|
| 1104 |
-
encoder_hidden_states: torch.Tensor,
|
| 1105 |
-
encoder_attention_mask: torch.Tensor | None = None,
|
| 1106 |
-
pooled_projections: torch.Tensor | None = None,
|
| 1107 |
-
image_embeds: torch.Tensor | None = None,
|
| 1108 |
-
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1109 |
-
return_dict: bool = True,
|
| 1110 |
-
tread_mixin: Optional[Any] = None,
|
| 1111 |
-
tread_disabled: bool = False,
|
| 1112 |
-
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
| 1113 |
-
"""
|
| 1114 |
-
Forward pass of the MotifVideoTransformer3DModel.
|
| 1115 |
-
|
| 1116 |
-
Args:
|
| 1117 |
-
hidden_states: Input latent tensor [B, C, F, H, W].
|
| 1118 |
-
timestep: Diffusion timesteps [B].
|
| 1119 |
-
encoder_hidden_states: Text conditioning [B, E, D].
|
| 1120 |
-
encoder_attention_mask: Mask for text conditioning [B, E].
|
| 1121 |
-
pooled_projections: Pooled text embeddings [B, D].
|
| 1122 |
-
image_embeds: Optional image embeddings from vision encoder [B, N, D].
|
| 1123 |
-
attention_kwargs: Additional arguments for attention processors.
|
| 1124 |
-
return_dict: Whether to return a Transformer2DModelOutput.
|
| 1125 |
-
tread_mixin: Optional TreadMixin instance for token reduction.
|
| 1126 |
-
tread_disabled: When True, force tread_mixin to None (dense pass).
|
| 1127 |
-
torch.compile specializes on this bool, producing separate graphs
|
| 1128 |
-
for dense vs routed without attribute toggling.
|
| 1129 |
-
|
| 1130 |
-
Returns:
|
| 1131 |
-
Transformer2DModelOutput or tuple containing the predicted samples.
|
| 1132 |
-
"""
|
| 1133 |
-
if tread_disabled:
|
| 1134 |
-
tread_mixin = None
|
| 1135 |
-
elif tread_mixin is None:
|
| 1136 |
-
tread_mixin = getattr(self, "_inference_tread_mixin", None)
|
| 1137 |
-
|
| 1138 |
-
if attention_kwargs is not None:
|
| 1139 |
-
attention_kwargs = attention_kwargs.copy()
|
| 1140 |
-
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 1141 |
-
else:
|
| 1142 |
-
lora_scale = 1.0
|
| 1143 |
-
|
| 1144 |
-
if USE_PEFT_BACKEND:
|
| 1145 |
-
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 1146 |
-
scale_lora_layers(self, lora_scale)
|
| 1147 |
-
else:
|
| 1148 |
-
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 1149 |
-
logger.warning(
|
| 1150 |
-
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
| 1151 |
-
)
|
| 1152 |
-
|
| 1153 |
-
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
| 1154 |
-
p, p_t = self.config.patch_size, self.config.patch_size_t
|
| 1155 |
-
post_patch_num_frames = num_frames // p_t
|
| 1156 |
-
post_patch_height = height // p
|
| 1157 |
-
post_patch_width = width // p
|
| 1158 |
-
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
|
| 1159 |
-
# 1. RoPE
|
| 1160 |
-
image_rotary_emb = self.rope(hidden_states, timestep=timestep)
|
| 1161 |
-
# 2. Conditional embeddings
|
| 1162 |
-
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)
|
| 1163 |
-
hidden_states = self.x_embedder(hidden_states)
|
| 1164 |
-
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 1165 |
-
|
| 1166 |
-
# First frame conditioning: Image embeddings from vision encoder
|
| 1167 |
-
if image_embeds is not None:
|
| 1168 |
-
# image_embeds: [B, N, D_img] -> [B, N, D]
|
| 1169 |
-
image_embeds = self.image_embedder(image_embeds)
|
| 1170 |
-
encoder_hidden_states = torch.cat([image_embeds, encoder_hidden_states], dim=1)
|
| 1171 |
-
# Extend attention mask for image tokens
|
| 1172 |
-
if encoder_attention_mask is not None:
|
| 1173 |
-
image_mask = torch.ones(
|
| 1174 |
-
image_embeds.shape[0],
|
| 1175 |
-
image_embeds.shape[1],
|
| 1176 |
-
device=encoder_attention_mask.device,
|
| 1177 |
-
dtype=encoder_attention_mask.dtype,
|
| 1178 |
-
)
|
| 1179 |
-
encoder_attention_mask = torch.cat([image_mask, encoder_attention_mask], dim=1)
|
| 1180 |
-
|
| 1181 |
-
# image_embed_seq_len: used by cross-attention blocks to slice text from encoder_hidden_states
|
| 1182 |
-
image_embed_seq_len = image_embeds.shape[1] if image_embeds is not None else 0
|
| 1183 |
-
|
| 1184 |
-
decoder_hidden_states = hidden_states.clone()
|
| 1185 |
-
|
| 1186 |
-
if encoder_attention_mask is not None:
|
| 1187 |
-
attention_mask = self._create_attention_mask(
|
| 1188 |
-
hidden_states=hidden_states,
|
| 1189 |
-
encoder_attention_mask=encoder_attention_mask,
|
| 1190 |
-
)
|
| 1191 |
-
else:
|
| 1192 |
-
attention_mask = None
|
| 1193 |
-
|
| 1194 |
-
# TREAD state initialization: manage token reduction manually to support activation checkpointing
|
| 1195 |
-
tread_active = False
|
| 1196 |
-
current_route = None
|
| 1197 |
-
ids_keep = None
|
| 1198 |
-
x_full = None
|
| 1199 |
-
orig_mask = attention_mask
|
| 1200 |
-
orig_rope = image_rotary_emb
|
| 1201 |
-
latent_len = hidden_states.shape[1]
|
| 1202 |
-
|
| 1203 |
-
# 4. Dual stream transformer blocks (Encoder)
|
| 1204 |
-
for i, block in enumerate(self.transformer_blocks):
|
| 1205 |
-
# Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
|
| 1206 |
-
if is_tread_start(tread_mixin, tread_active, i):
|
| 1207 |
-
tread_active = True
|
| 1208 |
-
current_route = tread_mixin._tread_route
|
| 1209 |
-
# Reduce sequence length at the start of a TREAD route
|
| 1210 |
-
ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
|
| 1211 |
-
x_full = hidden_states.contiguous()
|
| 1212 |
-
hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
|
| 1213 |
-
attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
|
| 1214 |
-
image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
|
| 1215 |
-
|
| 1216 |
-
hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
|
| 1217 |
-
block,
|
| 1218 |
-
hidden_states,
|
| 1219 |
-
encoder_hidden_states,
|
| 1220 |
-
temb,
|
| 1221 |
-
attention_mask,
|
| 1222 |
-
image_rotary_emb,
|
| 1223 |
-
token_replace_emb,
|
| 1224 |
-
first_frame_num_tokens,
|
| 1225 |
-
image_embed_seq_len,
|
| 1226 |
-
encoder_attention_mask,
|
| 1227 |
-
)
|
| 1228 |
-
|
| 1229 |
-
if is_tread_end(tread_mixin, tread_active, i):
|
| 1230 |
-
# Restore full sequence length at the end of a TREAD route
|
| 1231 |
-
hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
|
| 1232 |
-
tread_active = False
|
| 1233 |
-
current_route = None
|
| 1234 |
-
ids_keep = None
|
| 1235 |
-
x_full = None
|
| 1236 |
-
attention_mask = orig_mask
|
| 1237 |
-
image_rotary_emb = orig_rope
|
| 1238 |
-
|
| 1239 |
-
# We need to unwrap the blocks because CheckpointWrapper does not support len(),
|
| 1240 |
-
# which is required for slicing the blocks into encoder and decoder parts.
|
| 1241 |
-
single_transformer_blocks = self.single_transformer_blocks
|
| 1242 |
-
|
| 1243 |
-
# 5. Single stream transformer blocks (Encoder)
|
| 1244 |
-
num_dual = len(self.transformer_blocks)
|
| 1245 |
-
for i, block in enumerate(
|
| 1246 |
-
single_transformer_blocks[: len(single_transformer_blocks) - self.num_decoder_layers]
|
| 1247 |
-
):
|
| 1248 |
-
# Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
|
| 1249 |
-
abs_i = num_dual + i
|
| 1250 |
-
if is_tread_start(tread_mixin, tread_active, abs_i):
|
| 1251 |
-
tread_active = True
|
| 1252 |
-
current_route = tread_mixin._tread_route
|
| 1253 |
-
# Reduce sequence length at the start of a TREAD route
|
| 1254 |
-
ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
|
| 1255 |
-
x_full = hidden_states.contiguous()
|
| 1256 |
-
hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
|
| 1257 |
-
attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
|
| 1258 |
-
image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
|
| 1259 |
-
|
| 1260 |
-
hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
|
| 1261 |
-
block,
|
| 1262 |
-
hidden_states,
|
| 1263 |
-
encoder_hidden_states,
|
| 1264 |
-
temb,
|
| 1265 |
-
attention_mask,
|
| 1266 |
-
image_rotary_emb,
|
| 1267 |
-
token_replace_emb,
|
| 1268 |
-
first_frame_num_tokens,
|
| 1269 |
-
image_embed_seq_len,
|
| 1270 |
-
encoder_attention_mask,
|
| 1271 |
-
)
|
| 1272 |
-
|
| 1273 |
-
if is_tread_end(tread_mixin, tread_active, abs_i):
|
| 1274 |
-
# Restore full sequence length at the end of a TREAD route
|
| 1275 |
-
hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
|
| 1276 |
-
tread_active = False
|
| 1277 |
-
current_route = None
|
| 1278 |
-
ids_keep = None
|
| 1279 |
-
x_full = None
|
| 1280 |
-
attention_mask = orig_mask
|
| 1281 |
-
image_rotary_emb = orig_rope
|
| 1282 |
-
|
| 1283 |
-
# 6. Single stream transformer blocks (Decoder)
|
| 1284 |
-
if self.num_decoder_layers > 0:
|
| 1285 |
-
encoder_hidden_states = hidden_states
|
| 1286 |
-
attention_mask = None
|
| 1287 |
-
|
| 1288 |
-
num_single = len(single_transformer_blocks)
|
| 1289 |
-
|
| 1290 |
-
for i, block in enumerate(single_transformer_blocks[-self.num_decoder_layers :]):
|
| 1291 |
-
abs_i = num_dual + (num_single - self.num_decoder_layers) + i
|
| 1292 |
-
if is_tread_start(tread_mixin, tread_active, abs_i):
|
| 1293 |
-
tread_active = True
|
| 1294 |
-
current_route = tread_mixin._tread_route
|
| 1295 |
-
# Reduce sequence length at the start of a TREAD route
|
| 1296 |
-
ids_keep = tread_mixin.keep_indices(decoder_hidden_states, current_route["sel"]).to(
|
| 1297 |
-
decoder_hidden_states.device
|
| 1298 |
-
)
|
| 1299 |
-
x_full = encoder_hidden_states.contiguous()
|
| 1300 |
-
x_t_full = decoder_hidden_states.contiguous()
|
| 1301 |
-
decoder_hidden_states = tread_mixin.gather_tokens(decoder_hidden_states, ids_keep)
|
| 1302 |
-
encoder_hidden_states = tread_mixin.gather_tokens(encoder_hidden_states, ids_keep)
|
| 1303 |
-
attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
|
| 1304 |
-
image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
|
| 1305 |
-
|
| 1306 |
-
decoder_hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
|
| 1307 |
-
block,
|
| 1308 |
-
decoder_hidden_states,
|
| 1309 |
-
encoder_hidden_states,
|
| 1310 |
-
temb,
|
| 1311 |
-
attention_mask,
|
| 1312 |
-
image_rotary_emb,
|
| 1313 |
-
token_replace_emb,
|
| 1314 |
-
first_frame_num_tokens,
|
| 1315 |
-
)
|
| 1316 |
-
|
| 1317 |
-
if is_tread_end(tread_mixin, tread_active, abs_i):
|
| 1318 |
-
# Restore full sequence length at the end of a TREAD route
|
| 1319 |
-
decoder_hidden_states = tread_mixin.scatter_tokens(decoder_hidden_states, ids_keep, x_t_full)
|
| 1320 |
-
encoder_hidden_states = tread_mixin.scatter_tokens(encoder_hidden_states, ids_keep, x_full)
|
| 1321 |
-
tread_active = False
|
| 1322 |
-
current_route = None
|
| 1323 |
-
ids_keep = None
|
| 1324 |
-
x_full = None
|
| 1325 |
-
x_t_full = None
|
| 1326 |
-
attention_mask = orig_mask
|
| 1327 |
-
image_rotary_emb = orig_rope
|
| 1328 |
-
|
| 1329 |
-
hidden_states = decoder_hidden_states
|
| 1330 |
-
|
| 1331 |
-
# 7. Output projection
|
| 1332 |
-
hidden_states = self.norm_out(hidden_states, temb)
|
| 1333 |
-
hidden_states = self.proj_out(hidden_states)
|
| 1334 |
-
|
| 1335 |
-
hidden_states = hidden_states.reshape(
|
| 1336 |
-
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
|
| 1337 |
-
)
|
| 1338 |
-
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
| 1339 |
-
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
| 1340 |
-
|
| 1341 |
-
if USE_PEFT_BACKEND:
|
| 1342 |
-
# remove `lora_scale` from each PEFT layer
|
| 1343 |
-
unscale_lora_layers(self, lora_scale)
|
| 1344 |
-
|
| 1345 |
-
if not return_dict:
|
| 1346 |
-
return (hidden_states,)
|
| 1347 |
-
|
| 1348 |
-
return Transformer2DModelOutput(
|
| 1349 |
-
sample=hidden_states,
|
| 1350 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|