[WIP] diffusers integration

#21
by kencwt - opened
.gitattributes CHANGED
@@ -40,3 +40,13 @@ assets/showcase_t2v.png filter=lfs diff=lfs merge=lfs -text
40
  tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
41
  assets/i2v_sample.jpg filter=lfs diff=lfs merge=lfs -text
42
  motif-video-technical-report.pdf filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
40
  tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
41
  assets/i2v_sample.jpg filter=lfs diff=lfs merge=lfs -text
42
  motif-video-technical-report.pdf filter=lfs diff=lfs merge=lfs -text
43
+ assets/astronaut.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ assets/bird.mp4 filter=lfs diff=lfs merge=lfs -text
45
+ assets/fisherman.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ assets/underwater.mp4 filter=lfs diff=lfs merge=lfs -text
47
+ assets/vows.mp4 filter=lfs diff=lfs merge=lfs -text
48
+ assets/woman.mp4 filter=lfs diff=lfs merge=lfs -text
49
+ assets/sage_compare_BF16.webp filter=lfs diff=lfs merge=lfs -text
50
+ assets/sage_compare_Q4_K_M.webp filter=lfs diff=lfs merge=lfs -text
51
+ assets/sage_compare_Q5_K_M.webp filter=lfs diff=lfs merge=lfs -text
52
+ assets/sage_compare_Q8_0.webp filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -9,6 +9,22 @@ tags:
9
  - diffusion-transformer
10
  pipeline_tag: text-to-video
11
  library_name: diffusers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ---
13
 
14
  <p align="center">
@@ -31,8 +47,25 @@ library_name: diffusers
31
 
32
  ---
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  ## 🔥 News
35
 
 
 
36
  - **[2026-04-14]** We release **Motif-Video 2B**, our 2B-parameter text-to-video and image-to-video diffusion transformer, together with the full [technical report](https://arxiv.org/abs/2604.16503).
37
 
38
  ---
@@ -108,41 +141,74 @@ For the full derivation of why Shared Cross-Attention shares K/V but not Q, and
108
  ### Requirements
109
 
110
  - Python 3.10+
111
- - CUDA-capable GPU with **30GB+ VRAM** (e.g., A100, H100) — for 24GB GPUs see [Memory-efficient Inference](#-memory-efficient-inference)
112
 
113
  ```bash
114
- pip install "diffusers>=0.35.2" "transformers>=5.5.4" torch accelerate ftfy einops sentencepiece regex Pillow imageio imageio-ffmpeg
 
115
  ```
116
 
117
  ### Text-to-Video (T2V)
118
 
119
  ```python
120
  import torch
121
- from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
 
 
 
 
122
  from diffusers.utils import export_to_video
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  guider = AdaptiveProjectedGuidance(
125
  guidance_scale=8.0,
126
  adaptive_projected_guidance_rescale=12.0,
127
  adaptive_projected_guidance_momentum=0.1,
128
  use_original_formulation=True,
 
129
  )
130
 
131
- pipe = DiffusionPipeline.from_pretrained(
132
  "Motif-Technologies/Motif-Video-2B",
133
- custom_pipeline="pipeline_motif_video",
134
- trust_remote_code=True,
135
  torch_dtype=torch.bfloat16,
136
  guider=guider,
137
  )
 
 
 
 
 
 
 
 
 
 
138
  pipe = pipe.to("cuda")
139
 
140
  output = pipe(
141
- prompt="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
 
142
  height=736,
143
  width=1280,
144
  num_frames=121,
145
  num_inference_steps=50,
 
 
146
  )
147
 
148
  export_to_video(output.frames[0], "output.mp4", fps=24)
@@ -152,7 +218,11 @@ export_to_video(output.frames[0], "output.mp4", fps=24)
152
 
153
  ```python
154
  import torch
155
- from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
 
 
 
 
156
  from diffusers.utils import export_to_video, load_image
157
 
158
  guider = AdaptiveProjectedGuidance(
@@ -160,26 +230,38 @@ guider = AdaptiveProjectedGuidance(
160
  adaptive_projected_guidance_rescale=12.0,
161
  adaptive_projected_guidance_momentum=0.1,
162
  use_original_formulation=True,
 
163
  )
164
 
165
- pipe = DiffusionPipeline.from_pretrained(
166
  "Motif-Technologies/Motif-Video-2B",
167
- custom_pipeline="pipeline_motif_video",
168
- trust_remote_code=True,
169
  torch_dtype=torch.bfloat16,
170
  guider=guider,
171
  )
 
 
 
 
 
 
 
 
 
172
  pipe = pipe.to("cuda")
173
 
174
  image = load_image("https://huggingface.co/Motif-Technologies/Motif-Video-2B/resolve/main/assets/i2v_sample.jpg")
175
 
176
  output = pipe(
177
- prompt="Three friends stride through a sun-bleached meadow as a warm breeze ripples the tall dry grass around their legs. The woman on the left turns her head to share a quiet laugh, the woman in the center pushes a loose curl behind her ear, and the man on the right tilts his face toward the sky. The camera drifts gently alongside them at walking pace, handheld, with soft overcast light.",
 
178
  image=image,
179
  height=736,
180
  width=1280,
181
  num_frames=121,
182
  num_inference_steps=50,
 
 
183
  )
184
 
185
  export_to_video(output.frames[0], "output.mp4", fps=24)
@@ -188,96 +270,66 @@ export_to_video(output.frames[0], "output.mp4", fps=24)
188
  ### CLI Inference
189
 
190
  ```bash
191
- # Text-to-Video
192
  python inference.py \
193
- --prompt "A time-lapse of a flower blooming in a dark room, dramatic lighting" \
194
  --output t2v_output.mp4
195
 
196
- # Image-to-Video
197
  python inference.py \
198
- --image assets/i2v_sample.jpg \
199
- --prompt "Three friends stride through a meadow as a warm breeze ripples the tall grass" \
200
- --output i2v_output.mp4
201
  ```
202
 
203
- See `inference.py` for all available options (`--help`).
204
 
205
  ### Recommended Settings
206
 
207
  | Parameter | Default | Notes |
208
  |---|---|---|
209
- | Resolution | 1280x736 | 720p, best quality |
210
  | Frames | 121 | ~5 seconds at 24fps |
211
- | Guidance scale | 8.0 | |
 
212
  | Inference steps | 50 | |
 
 
213
  | dtype | bfloat16 | Recommended for H100/A100 |
214
 
215
  ### 🔋 Memory-efficient Inference
216
 
217
- By default, `pipe.to("cuda")` loads all components onto the GPU simultaneously, requiring **~30 GB VRAM**.
218
-
219
- For GPUs with 24 GB or less (e.g. RTX 4090, RTX 3090), use `enable_model_cpu_offload()` with the `expandable_segments` allocator setting:
220
-
221
- ```bash
222
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
223
- ```
224
-
225
- ```python
226
- pipe = DiffusionPipeline.from_pretrained(
227
- "Motif-Technologies/Motif-Video-2B",
228
- custom_pipeline="pipeline_motif_video",
229
- trust_remote_code=True,
230
- torch_dtype=torch.bfloat16,
231
- guider=guider, # see T2V example above
232
- )
233
- pipe.enable_model_cpu_offload() # replaces pipe.to("cuda")
234
-
235
- output = pipe(prompt="...", height=736, width=1280, num_frames=121, num_inference_steps=50)
236
- export_to_video(output.frames[0], "output.mp4", fps=24)
237
- ```
238
-
239
- This moves each component (text encoder → transformer → VAE) to GPU only when needed. The `expandable_segments` setting allows the CUDA memory allocator to efficiently reuse memory released by earlier components, avoiding fragmentation-related OOM errors.
240
-
241
- | Mode | Peak VRAM | Speed | Recommended GPU |
242
- |------|-----------|-------|-----------------|
243
- | `pipe.to("cuda")` | ~30 GB | Fastest | A100, H100, H200 |
244
- | `enable_model_cpu_offload()` | ~19 GB | Similar | RTX 4090, RTX 3090 |
245
 
246
- #### FP8 Weight Quantization (Optional)
 
 
 
 
247
 
248
- For further VRAM reduction, you can quantize the transformer weights to FP8 using [torchao](https://github.com/pytorch/ao):
249
 
250
- ```bash
251
- pip install torchao
252
- ```
253
 
254
- ```python
255
- from torchao.quantization import quantize_, Float8WeightOnlyConfig
256
 
257
- pipe = DiffusionPipeline.from_pretrained(
258
- "Motif-Technologies/Motif-Video-2B",
259
- custom_pipeline="pipeline_motif_video",
260
- trust_remote_code=True,
261
- torch_dtype=torch.bfloat16,
262
- guider=guider, # see T2V example above
263
- )
264
- quantize_(pipe.transformer, Float8WeightOnlyConfig())
265
- pipe.enable_model_cpu_offload()
266
 
267
- output = pipe(prompt="...", height=736, width=1280, num_frames=121, num_inference_steps=50)
268
- export_to_video(output.frames[0], "output.mp4", fps=24)
269
- ```
 
 
270
 
271
- This stores the transformer weights in FP8 (8-bit) instead of BF16 (16-bit), reducing peak VRAM from ~19 GB to ~15 GB while keeping all computation in BF16 precision.
272
 
273
- | Mode | Peak VRAM | Notes |
274
- |------|-----------|-------|
275
- | `enable_model_cpu_offload()` | ~19 GB | BF16 baseline |
276
- | `+ Float8WeightOnlyConfig` | ~15 GB | FP8 weights, BF16 compute |
277
 
278
  ### 🖥️ ComfyUI
279
 
280
- Official ComfyUI custom nodes for Motif-Video 2B are currently in development. Stay tuned for updates.
 
 
281
 
282
  ---
283
 
 
9
  - diffusion-transformer
10
  pipeline_tag: text-to-video
11
  library_name: diffusers
12
+ widget:
13
+ - text: "A vibrant blue jay perches gracefully on a slender branch, its feathers shimmering in the soft morning light. The bird's keen eyes scan the surroundings, capturing the essence of the tranquil forest. It flutters its wings briefly, showcasing the intricate patterns of blue, white, and black on its plumage. The background reveals a lush canopy of green leaves, with rays of sunlight filtering through, creating a dappled effect on the forest floor. The blue jay then tilts its head, emitting a melodious call that echoes through the serene woodland, adding a touch of magic to the peaceful scene."
14
+ output:
15
+ url: assets/bird.mp4
16
+ - text: "Underwater footage of a vibrant coral reef ecosystem with tropical fish swimming through coral formations. Natural sunlight filtering down through clear water creates dancing light patterns on the reef. Smooth underwater camera movement, natural color correction preserving authentic ocean blues and coral colors, documentary marine biology style, peaceful and educational mood."
17
+ output:
18
+ url: assets/underwater.mp4
19
+ - text: "An old fisherman mends his nets on a stone harbor wall, weathered hands moving with practiced speed through the green mesh. Shot on a 50mm lens with a slow dolly-in from his side, the afternoon sun throws warm light across his salt-stained coat and the worn granite beneath him. Behind him, a single wooden boat bobs gently in a turquoise bay. Gulls drift through the distant sky in soft focus. The camera settles on his hands, then racks focus to his weathered, squinting eyes."
20
+ output:
21
+ url: assets/fisherman.mp4
22
+ - text: "A lone astronaut drifts just outside a derelict space station, tethered by a single silver line as Earth's terminator glows blue-white behind her. Shot with a slow wide-to-medium push, the camera floats alongside her in weightless silence, the curvature of the planet filling the lower third of the frame. Sunlight rakes across the hull's scarred panels, casting long hard shadows that stretch and shift as she rotates. Her visor reflects the aurora below, ribbons of green pulling across the glass. She reaches out with a gloved hand and lets her fingertips graze a dented antenna, the gesture small and reverent."
23
+ output:
24
+ url: assets/astronaut.mp4
25
+ - text: "A woman standing in a sunlit field as flower petals swirl around her in slow motion. Each petal floats gently through the golden light, casting tiny shadows. Her hair moves like water, and time seems to stand still."
26
+ output:
27
+ url: assets/woman.mp4
28
  ---
29
 
30
  <p align="center">
 
47
 
48
  ---
49
 
50
+ <!--
51
+ NOTE: This README is written against the CURRENT state of diffusers PR #13551
52
+ (pre-merge). The PR currently has issues:
53
+ - negative_prompt defaults to None (should be built-in)
54
+ - use_linear_quadratic_schedule defaults to True (should be False)
55
+ - DPMSolverMultistepScheduler crashes (pipeline always passes sigmas)
56
+ - No built-in SageAttention support (requires manual patching)
57
+
58
+ Code examples below include workarounds (explicit negative_prompt,
59
+ use_linear_quadratic_schedule=False, _FlowDPMSolver subclass).
60
+
61
+ TODO: Update after PR feedback is applied, and again after merge.
62
+ Tracking: https://github.com/MotifTechnologies/diffusers/pull/1
63
+ -->
64
+
65
  ## 🔥 News
66
 
67
+ - **[2026-04-28]** **ComfyUI custom nodes** released: [ComfyUI-MotifVideo2B](https://github.com/MotifTechnologies/ComfyUI-MotifVideo2B). GGUF workflow support coming soon.
68
+ - **[2026-04-28]** **GGUF quantized weights** now available at [Motif-Video-2B-GGUF](https://huggingface.co/Motif-Technologies/Motif-Video-2B-GGUF) — up to 2.7 GB VRAM savings with no speed penalty. **SageAttention** support for ~2× faster inference. See [GGUF + SageAttention](#🧊-gguf--sageattention) below.
69
  - **[2026-04-14]** We release **Motif-Video 2B**, our 2B-parameter text-to-video and image-to-video diffusion transformer, together with the full [technical report](https://arxiv.org/abs/2604.16503).
70
 
71
  ---
 
141
  ### Requirements
142
 
143
  - Python 3.10+
144
+ - CUDA-capable GPU with **30GB+ VRAM** (e.g., A100, H100) — for 24GB GPUs see [Memory-efficient Inference](🔋-memory-efficient-inference)
145
 
146
  ```bash
147
+ pip install "transformers>=5.5.4" torch accelerate ftfy einops sentencepiece regex Pillow imageio imageio-ffmpeg
148
+ pip install git+https://github.com/waitingcheung/diffusers.git@feat/motif-video
149
  ```
150
 
151
  ### Text-to-Video (T2V)
152
 
153
  ```python
154
  import torch
155
+ from diffusers import (
156
+ AdaptiveProjectedGuidance,
157
+ DPMSolverMultistepScheduler,
158
+ MotifVideoPipeline,
159
+ )
160
  from diffusers.utils import export_to_video
161
 
162
+
163
+ # DPMSolver++ subclass: ignores pipeline-supplied sigmas and builds its own
164
+ # flow-matching schedule. Will be unnecessary once PR #13551 adds the
165
+ # _is_flow_multistep branch.
166
+ class FlowDPMSolver(DPMSolverMultistepScheduler):
167
+ def set_timesteps(self, num_inference_steps=None, device=None,
168
+ sigmas=None, mu=None, timesteps=None):
169
+ if sigmas is not None and num_inference_steps is None:
170
+ num_inference_steps = len(sigmas)
171
+ super().set_timesteps(
172
+ num_inference_steps=num_inference_steps,
173
+ device=device, timesteps=timesteps,
174
+ )
175
+
176
+
177
  guider = AdaptiveProjectedGuidance(
178
  guidance_scale=8.0,
179
  adaptive_projected_guidance_rescale=12.0,
180
  adaptive_projected_guidance_momentum=0.1,
181
  use_original_formulation=True,
182
+ normalization_dims="spatial",
183
  )
184
 
185
+ pipe = MotifVideoPipeline.from_pretrained(
186
  "Motif-Technologies/Motif-Video-2B",
187
+ revision="diffusers-integration",
 
188
  torch_dtype=torch.bfloat16,
189
  guider=guider,
190
  )
191
+
192
+ # DPMSolver++ for faster convergence
193
+ pipe.scheduler = FlowDPMSolver(
194
+ num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
195
+ algorithm_type="dpmsolver++",
196
+ solver_order=2,
197
+ prediction_type="flow_prediction",
198
+ use_flow_sigmas=True,
199
+ flow_shift=15.0,
200
+ )
201
  pipe = pipe.to("cuda")
202
 
203
  output = pipe(
204
+ prompt="A woman standing in a sunlit field as flower petals swirl around her in slow motion. Each petal floats gently through the golden light, casting tiny shadows. Her hair moves like water, and time seems to stand still.",
205
+ negative_prompt="text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
206
  height=736,
207
  width=1280,
208
  num_frames=121,
209
  num_inference_steps=50,
210
+ frame_rate=24,
211
+ use_linear_quadratic_schedule=False,
212
  )
213
 
214
  export_to_video(output.frames[0], "output.mp4", fps=24)
 
218
 
219
  ```python
220
  import torch
221
+ from diffusers import (
222
+ AdaptiveProjectedGuidance,
223
+ DPMSolverMultistepScheduler,
224
+ MotifVideoPipeline,
225
+ )
226
  from diffusers.utils import export_to_video, load_image
227
 
228
  guider = AdaptiveProjectedGuidance(
 
230
  adaptive_projected_guidance_rescale=12.0,
231
  adaptive_projected_guidance_momentum=0.1,
232
  use_original_formulation=True,
233
+ normalization_dims="spatial",
234
  )
235
 
236
+ pipe = MotifVideoPipeline.from_pretrained(
237
  "Motif-Technologies/Motif-Video-2B",
238
+ revision="diffusers-integration",
 
239
  torch_dtype=torch.bfloat16,
240
  guider=guider,
241
  )
242
+
243
+ pipe.scheduler = FlowDPMSolver(
244
+ num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
245
+ algorithm_type="dpmsolver++",
246
+ solver_order=2,
247
+ prediction_type="flow_prediction",
248
+ use_flow_sigmas=True,
249
+ flow_shift=15.0,
250
+ )
251
  pipe = pipe.to("cuda")
252
 
253
  image = load_image("https://huggingface.co/Motif-Technologies/Motif-Video-2B/resolve/main/assets/i2v_sample.jpg")
254
 
255
  output = pipe(
256
+ prompt="Three friends stride through a sun-bleached meadow as a warm breeze ripples the tall dry grass around their legs.",
257
+ negative_prompt="text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
258
  image=image,
259
  height=736,
260
  width=1280,
261
  num_frames=121,
262
  num_inference_steps=50,
263
+ frame_rate=24,
264
+ use_linear_quadratic_schedule=False,
265
  )
266
 
267
  export_to_video(output.frames[0], "output.mp4", fps=24)
 
270
  ### CLI Inference
271
 
272
  ```bash
273
+ # Text-to-Video (default settings)
274
  python inference.py \
275
+ --prompt "A woman standing in a sunlit field as..." \
276
  --output t2v_output.mp4
277
 
278
+ # With SageAttention (~2x faster, requires sageattention package)
279
  python inference.py \
280
+ --prompt "Three friends stride through a sun-bleached meadow..." \
281
+ --use-sage-attention \
282
+ --output t2v_output.mp4
283
  ```
284
 
285
+ See `inference.py --help` for all available options.
286
 
287
  ### Recommended Settings
288
 
289
  | Parameter | Default | Notes |
290
  |---|---|---|
291
+ | Resolution | 1280×736 | 720p, best quality |
292
  | Frames | 121 | ~5 seconds at 24fps |
293
+ | Scheduler | DPMSolver++ | `solver_order=2`, `flow_shift=15.0` |
294
+ | Guidance scale | 8.0 | With APG (`normalization_dims="spatial"`) |
295
  | Inference steps | 50 | |
296
+ | Negative prompt | (built-in) | See code examples above |
297
+ | `use_linear_quadratic_schedule` | `False` | Must be set explicitly |
298
  | dtype | bfloat16 | Recommended for H100/A100 |
299
 
300
  ### 🔋 Memory-efficient Inference
301
 
302
+ For GPUs with 24 GB or less (e.g. RTX 4090, RTX 3090), CPU offloading and FP8 quantization can reduce peak VRAM from ~30 GB to ~15 GB with minimal speed impact.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
+ | Mode | Peak VRAM | Recommended GPU |
305
+ |------|-----------|-----------------|
306
+ | `pipe.to("cuda")` | ~30 GB | A100, H100, H200 |
307
+ | `enable_model_cpu_offload()` | ~19 GB | RTX 4090, RTX 3090 |
308
+ | `+ FP8 quantization` | ~15 GB | RTX 4090, RTX 3090 |
309
 
310
+ > **Full guide** [docs/memory-efficient-inference.md](docs/memory-efficient-inference.md)
311
 
312
+ ---
 
 
313
 
314
+ ### 🧊 GGUF + SageAttention
 
315
 
316
+ GGUF quantized weights at [Motif-Video-2B-GGUF](https://huggingface.co/Motif-Technologies/Motif-Video-2B-GGUF) — up to 2.7 GB VRAM savings with no speed penalty. Combined with [SageAttention](https://github.com/thu-ml/SageAttention) for ~1.6× faster inference.
 
 
 
 
 
 
 
 
317
 
318
+ | Variant | Sage (s/it) | Speedup | Peak alloc (GB) |
319
+ |---------|------------|---------|-----------------|
320
+ | BF16 | 14.75 | 1.58x | 15.12 |
321
+ | Q8_0 | 14.49 | 1.60x | 13.44 |
322
+ | Q4_K_M | 14.59 | 1.60x | 12.53 |
323
 
324
+ > **Full guide** [docs/gguf-sageattention.md](docs/gguf-sageattention.md)
325
 
326
+ ---
 
 
 
327
 
328
  ### 🖥️ ComfyUI
329
 
330
+ Official ComfyUI custom nodes: [ComfyUI-MotifVideo2B](https://github.com/MotifTechnologies/ComfyUI-MotifVideo2B)
331
+
332
+ > **Note:** Currently requires **High VRAM** mode. GGUF quantized model loading in ComfyUI is in progress.
333
 
334
  ---
335
 
_fm_solvers_unipc.py DELETED
@@ -1,759 +0,0 @@
1
- # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
2
- # Convert unipc for flow matching
3
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
-
5
- import math
6
- from typing import List, Optional, Tuple, Union
7
-
8
- import numpy as np
9
- import torch
10
- from diffusers.configuration_utils import ConfigMixin, register_to_config
11
- from diffusers.schedulers.scheduling_utils import (
12
- KarrasDiffusionSchedulers,
13
- SchedulerMixin,
14
- SchedulerOutput,
15
- )
16
- from diffusers.utils import deprecate, is_scipy_available
17
-
18
-
19
- if is_scipy_available():
20
- pass
21
-
22
-
23
- class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
24
- """
25
- `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
26
-
27
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
28
- methods the library implements for all schedulers such as loading and saving.
29
-
30
- Args:
31
- num_train_timesteps (`int`, defaults to 1000):
32
- The number of diffusion steps to train the model.
33
- solver_order (`int`, default `2`):
34
- The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
35
- due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
36
- unconditional sampling.
37
- prediction_type (`str`, defaults to "flow_prediction"):
38
- Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
39
- the flow of the diffusion process.
40
- thresholding (`bool`, defaults to `False`):
41
- Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
42
- as Stable Diffusion.
43
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
44
- The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
45
- sample_max_value (`float`, defaults to 1.0):
46
- The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
47
- predict_x0 (`bool`, defaults to `True`):
48
- Whether to use the updating algorithm on the predicted x0.
49
- solver_type (`str`, default `bh2`):
50
- Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
51
- otherwise.
52
- lower_order_final (`bool`, default `True`):
53
- Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
54
- stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
55
- disable_corrector (`list`, default `[]`):
56
- Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
57
- and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
58
- usually disabled during the first few steps.
59
- solver_p (`SchedulerMixin`, default `None`):
60
- Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
61
- use_karras_sigmas (`bool`, *optional*, defaults to `False`):
62
- Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
63
- the sigmas are determined according to a sequence of noise levels {σi}.
64
- use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
65
- Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
66
- timestep_spacing (`str`, defaults to `"linspace"`):
67
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
68
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
69
- steps_offset (`int`, defaults to 0):
70
- An offset added to the inference steps, as required by some model families.
71
- final_sigmas_type (`str`, defaults to `"zero"`):
72
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
73
- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
74
- """
75
-
76
- _compatibles = [e.name for e in KarrasDiffusionSchedulers]
77
- order = 1
78
-
79
- @register_to_config
80
- def __init__(
81
- self,
82
- num_train_timesteps: int = 1000,
83
- solver_order: int = 2,
84
- prediction_type: str = "flow_prediction",
85
- shift: Optional[float] = 1.0,
86
- use_dynamic_shifting=False,
87
- thresholding: bool = False,
88
- dynamic_thresholding_ratio: float = 0.995,
89
- sample_max_value: float = 1.0,
90
- predict_x0: bool = True,
91
- solver_type: str = "bh2",
92
- lower_order_final: bool = True,
93
- disable_corrector: List[int] = [],
94
- solver_p: Optional[SchedulerMixin] = None,
95
- timestep_spacing: str = "linspace",
96
- steps_offset: int = 0,
97
- final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
98
- ):
99
- if solver_type not in ["bh1", "bh2"]:
100
- if solver_type in ["midpoint", "heun", "logrho"]:
101
- self.register_to_config(solver_type="bh2")
102
- else:
103
- raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
104
-
105
- self.predict_x0 = predict_x0
106
- # setable values
107
- self.num_inference_steps = None
108
- alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
109
- sigmas = 1.0 - alphas
110
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
111
-
112
- if not use_dynamic_shifting:
113
- # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
114
- sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
115
-
116
- self.sigmas = sigmas
117
- self.timesteps = sigmas * num_train_timesteps
118
-
119
- self.model_outputs = [None] * solver_order
120
- self.timestep_list = [None] * solver_order
121
- self.lower_order_nums = 0
122
- self.disable_corrector = disable_corrector
123
- self.solver_p = solver_p
124
- self.last_sample = None
125
- self._step_index = None
126
- self._begin_index = None
127
-
128
- self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
129
- self.sigma_min = self.sigmas[-1].item()
130
- self.sigma_max = self.sigmas[0].item()
131
-
132
- @property
133
- def step_index(self):
134
- """
135
- The index counter for current timestep. It will increase 1 after each scheduler step.
136
- """
137
- return self._step_index
138
-
139
- @property
140
- def begin_index(self):
141
- """
142
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
143
- """
144
- return self._begin_index
145
-
146
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
147
- def set_begin_index(self, begin_index: int = 0):
148
- """
149
- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
150
-
151
- Args:
152
- begin_index (`int`):
153
- The begin index for the scheduler.
154
- """
155
- self._begin_index = begin_index
156
-
157
- # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
158
- def set_timesteps(
159
- self,
160
- num_inference_steps: Union[int, None] = None,
161
- device: Optional[Union[str, torch.device]] = None,
162
- sigmas: Optional[List[float]] = None,
163
- mu: Optional[Union[float, None]] = None,
164
- shift: Optional[Union[float, None]] = None,
165
- ):
166
- """
167
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
168
- Args:
169
- num_inference_steps (`int`):
170
- Total number of the spacing of the time steps.
171
- device (`str` or `torch.device`, *optional*):
172
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
173
- """
174
-
175
- if self.config.use_dynamic_shifting and mu is None:
176
- raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
177
-
178
- if sigmas is None:
179
- sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore
180
-
181
- if self.config.use_dynamic_shifting:
182
- sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
183
- else:
184
- if shift is None:
185
- shift = self.config.shift
186
- sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
187
-
188
- if self.config.final_sigmas_type == "sigma_min":
189
- sigma_last = self.config.sigma_min
190
- elif self.config.final_sigmas_type == "zero":
191
- sigma_last = 0
192
- else:
193
- raise ValueError(
194
- f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
195
- )
196
-
197
- timesteps = sigmas * self.config.num_train_timesteps
198
- sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
199
-
200
- self.sigmas = torch.from_numpy(sigmas)
201
- self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
202
-
203
- self.num_inference_steps = len(timesteps)
204
-
205
- self.model_outputs = [
206
- None,
207
- ] * self.config.solver_order
208
- self.lower_order_nums = 0
209
- self.last_sample = None
210
- if self.solver_p:
211
- self.solver_p.set_timesteps(self.num_inference_steps, device=device)
212
-
213
- # add an index counter for schedulers that allow duplicated timesteps
214
- self._step_index = None
215
- self._begin_index = None
216
- self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
217
-
218
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
219
- def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
220
- """
221
- "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
222
- prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
223
- s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
224
- pixels from saturation at each step. We find that dynamic thresholding results in significantly better
225
- photorealism as well as better image-text alignment, especially when using very large guidance weights."
226
-
227
- https://arxiv.org/abs/2205.11487
228
- """
229
- dtype = sample.dtype
230
- batch_size, channels, *remaining_dims = sample.shape
231
-
232
- if dtype not in (torch.float32, torch.float64):
233
- sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
234
-
235
- # Flatten sample for doing quantile calculation along each image
236
- sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
237
-
238
- abs_sample = sample.abs() # "a certain percentile absolute pixel value"
239
-
240
- s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
241
- s = torch.clamp(
242
- s, min=1, max=self.config.sample_max_value
243
- ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
244
- s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
245
- sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
246
-
247
- sample = sample.reshape(batch_size, channels, *remaining_dims)
248
- sample = sample.to(dtype)
249
-
250
- return sample
251
-
252
- # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
253
- def _sigma_to_t(self, sigma):
254
- return sigma * self.config.num_train_timesteps
255
-
256
- def _sigma_to_alpha_sigma_t(self, sigma):
257
- return 1 - sigma, sigma
258
-
259
- # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
260
- def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
261
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
262
-
263
- def convert_model_output(
264
- self,
265
- model_output: torch.Tensor,
266
- *args,
267
- sample: Optional[torch.Tensor] = None,
268
- **kwargs,
269
- ) -> torch.Tensor:
270
- r"""
271
- Convert the model output to the corresponding type the UniPC algorithm needs.
272
-
273
- Args:
274
- model_output (`torch.Tensor`):
275
- The direct output from the learned diffusion model.
276
- timestep (`int`):
277
- The current discrete timestep in the diffusion chain.
278
- sample (`torch.Tensor`):
279
- A current instance of a sample created by the diffusion process.
280
-
281
- Returns:
282
- `torch.Tensor`:
283
- The converted model output.
284
- """
285
- timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
286
- if sample is None:
287
- if len(args) > 1:
288
- sample = args[1]
289
- else:
290
- raise ValueError("missing `sample` as a required keyward argument")
291
- if timestep is not None:
292
- deprecate(
293
- "timesteps",
294
- "1.0.0",
295
- "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
296
- )
297
-
298
- sigma = self.sigmas[self.step_index]
299
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
300
-
301
- if self.predict_x0:
302
- if self.config.prediction_type == "flow_prediction":
303
- sigma_t = self.sigmas[self.step_index]
304
- x0_pred = sample - sigma_t * model_output
305
- else:
306
- raise ValueError(
307
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
308
- " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
309
- )
310
-
311
- if self.config.thresholding:
312
- x0_pred = self._threshold_sample(x0_pred)
313
-
314
- return x0_pred
315
- else:
316
- if self.config.prediction_type == "flow_prediction":
317
- sigma_t = self.sigmas[self.step_index]
318
- epsilon = sample - (1 - sigma_t) * model_output
319
- else:
320
- raise ValueError(
321
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
322
- " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
323
- )
324
-
325
- if self.config.thresholding:
326
- sigma_t = self.sigmas[self.step_index]
327
- x0_pred = sample - sigma_t * model_output
328
- x0_pred = self._threshold_sample(x0_pred)
329
- epsilon = model_output + x0_pred
330
-
331
- return epsilon
332
-
333
- def multistep_uni_p_bh_update(
334
- self,
335
- model_output: torch.Tensor,
336
- *args,
337
- sample: Optional[torch.Tensor] = None,
338
- order: Optional[int] = None,
339
- **kwargs,
340
- ) -> torch.Tensor:
341
- """
342
- One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
343
-
344
- Args:
345
- model_output (`torch.Tensor`):
346
- The direct output from the learned diffusion model at the current timestep.
347
- prev_timestep (`int`):
348
- The previous discrete timestep in the diffusion chain.
349
- sample (`torch.Tensor`):
350
- A current instance of a sample created by the diffusion process.
351
- order (`int`):
352
- The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
353
-
354
- Returns:
355
- `torch.Tensor`:
356
- The sample tensor at the previous timestep.
357
- """
358
- prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
359
- if sample is None:
360
- if len(args) > 1:
361
- sample = args[1]
362
- else:
363
- raise ValueError(" missing `sample` as a required keyward argument")
364
- if order is None:
365
- if len(args) > 2:
366
- order = args[2]
367
- else:
368
- raise ValueError(" missing `order` as a required keyward argument")
369
- if prev_timestep is not None:
370
- deprecate(
371
- "prev_timestep",
372
- "1.0.0",
373
- "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
374
- )
375
- model_output_list = self.model_outputs
376
-
377
- s0 = self.timestep_list[-1]
378
- m0 = model_output_list[-1]
379
- x = sample
380
-
381
- if self.solver_p:
382
- x_t = self.solver_p.step(model_output, s0, x).prev_sample
383
- return x_t
384
-
385
- sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore
386
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
387
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
388
-
389
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
390
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
391
-
392
- h = lambda_t - lambda_s0
393
- device = sample.device
394
-
395
- rks = []
396
- D1s = []
397
- for i in range(1, order):
398
- si = self.step_index - i # pyright: ignore
399
- mi = model_output_list[-(i + 1)]
400
- alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
401
- lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
402
- rk = (lambda_si - lambda_s0) / h
403
- rks.append(rk)
404
- D1s.append((mi - m0) / rk) # pyright: ignore
405
-
406
- rks.append(1.0)
407
- rks = torch.tensor(rks, device=device)
408
-
409
- R = []
410
- b = []
411
-
412
- hh = -h if self.predict_x0 else h
413
- h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
414
- h_phi_k = h_phi_1 / hh - 1
415
-
416
- factorial_i = 1
417
-
418
- if self.config.solver_type == "bh1":
419
- B_h = hh
420
- elif self.config.solver_type == "bh2":
421
- B_h = torch.expm1(hh)
422
- else:
423
- raise NotImplementedError()
424
-
425
- for i in range(1, order + 1):
426
- R.append(torch.pow(rks, i - 1))
427
- b.append(h_phi_k * factorial_i / B_h)
428
- factorial_i *= i + 1
429
- h_phi_k = h_phi_k / hh - 1 / factorial_i
430
-
431
- R = torch.stack(R)
432
- b = torch.tensor(b, device=device)
433
-
434
- if len(D1s) > 0:
435
- D1s = torch.stack(D1s, dim=1) # (B, K)
436
- # for order 2, we use a simplified version
437
- if order == 2:
438
- rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
439
- else:
440
- rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
441
- else:
442
- D1s = None
443
-
444
- if self.predict_x0:
445
- x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
446
- if D1s is not None:
447
- pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
448
- else:
449
- pred_res = 0
450
- x_t = x_t_ - alpha_t * B_h * pred_res
451
- else:
452
- x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
453
- if D1s is not None:
454
- pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
455
- else:
456
- pred_res = 0
457
- x_t = x_t_ - sigma_t * B_h * pred_res
458
-
459
- x_t = x_t.to(x.dtype)
460
- return x_t
461
-
462
- def multistep_uni_c_bh_update(
463
- self,
464
- this_model_output: torch.Tensor,
465
- *args,
466
- last_sample: Optional[torch.Tensor] = None,
467
- this_sample: Optional[torch.Tensor] = None,
468
- order: Optional[int] = None,
469
- **kwargs,
470
- ) -> torch.Tensor:
471
- """
472
- One step for the UniC (B(h) version).
473
-
474
- Args:
475
- this_model_output (`torch.Tensor`):
476
- The model outputs at `x_t`.
477
- this_timestep (`int`):
478
- The current timestep `t`.
479
- last_sample (`torch.Tensor`):
480
- The generated sample before the last predictor `x_{t-1}`.
481
- this_sample (`torch.Tensor`):
482
- The generated sample after the last predictor `x_{t}`.
483
- order (`int`):
484
- The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
485
-
486
- Returns:
487
- `torch.Tensor`:
488
- The corrected sample tensor at the current timestep.
489
- """
490
- this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
491
- if last_sample is None:
492
- if len(args) > 1:
493
- last_sample = args[1]
494
- else:
495
- raise ValueError(" missing`last_sample` as a required keyward argument")
496
- if this_sample is None:
497
- if len(args) > 2:
498
- this_sample = args[2]
499
- else:
500
- raise ValueError(" missing`this_sample` as a required keyward argument")
501
- if order is None:
502
- if len(args) > 3:
503
- order = args[3]
504
- else:
505
- raise ValueError(" missing`order` as a required keyward argument")
506
- if this_timestep is not None:
507
- deprecate(
508
- "this_timestep",
509
- "1.0.0",
510
- "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
511
- )
512
-
513
- model_output_list = self.model_outputs
514
-
515
- m0 = model_output_list[-1]
516
- x = last_sample
517
- x_t = this_sample
518
- model_t = this_model_output
519
-
520
- sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore
521
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
522
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
523
-
524
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
525
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
526
-
527
- h = lambda_t - lambda_s0
528
- device = this_sample.device
529
-
530
- rks = []
531
- D1s = []
532
- for i in range(1, order):
533
- si = self.step_index - (i + 1) # pyright: ignore
534
- mi = model_output_list[-(i + 1)]
535
- alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
536
- lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
537
- rk = (lambda_si - lambda_s0) / h
538
- rks.append(rk)
539
- D1s.append((mi - m0) / rk) # pyright: ignore
540
-
541
- rks.append(1.0)
542
- rks = torch.tensor(rks, device=device)
543
-
544
- R = []
545
- b = []
546
-
547
- hh = -h if self.predict_x0 else h
548
- h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
549
- h_phi_k = h_phi_1 / hh - 1
550
-
551
- factorial_i = 1
552
-
553
- if self.config.solver_type == "bh1":
554
- B_h = hh
555
- elif self.config.solver_type == "bh2":
556
- B_h = torch.expm1(hh)
557
- else:
558
- raise NotImplementedError()
559
-
560
- for i in range(1, order + 1):
561
- R.append(torch.pow(rks, i - 1))
562
- b.append(h_phi_k * factorial_i / B_h)
563
- factorial_i *= i + 1
564
- h_phi_k = h_phi_k / hh - 1 / factorial_i
565
-
566
- R = torch.stack(R)
567
- b = torch.tensor(b, device=device)
568
-
569
- if len(D1s) > 0:
570
- D1s = torch.stack(D1s, dim=1)
571
- else:
572
- D1s = None
573
-
574
- # for order 1, we use a simplified version
575
- if order == 1:
576
- rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
577
- else:
578
- rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
579
-
580
- if self.predict_x0:
581
- x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
582
- if D1s is not None:
583
- corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
584
- else:
585
- corr_res = 0
586
- D1_t = model_t - m0
587
- x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
588
- else:
589
- x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
590
- if D1s is not None:
591
- corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
592
- else:
593
- corr_res = 0
594
- D1_t = model_t - m0
595
- x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
596
- x_t = x_t.to(x.dtype)
597
- return x_t
598
-
599
- def index_for_timestep(self, timestep, schedule_timesteps=None):
600
- if schedule_timesteps is None:
601
- schedule_timesteps = self.timesteps
602
-
603
- indices = (schedule_timesteps == timestep).nonzero()
604
-
605
- # The sigma index that is taken for the **very** first `step`
606
- # is always the second index (or the last index if there is only 1)
607
- # This way we can ensure we don't accidentally skip a sigma in
608
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
609
- pos = 1 if len(indices) > 1 else 0
610
-
611
- return indices[pos].item()
612
-
613
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
614
- def _init_step_index(self, timestep):
615
- """
616
- Initialize the step_index counter for the scheduler.
617
- """
618
-
619
- if self.begin_index is None:
620
- if isinstance(timestep, torch.Tensor):
621
- timestep = timestep.to(self.timesteps.device)
622
- self._step_index = self.index_for_timestep(timestep)
623
- else:
624
- self._step_index = self._begin_index
625
-
626
- def step(
627
- self,
628
- model_output: torch.Tensor,
629
- timestep: Union[int, torch.Tensor],
630
- sample: torch.Tensor,
631
- return_dict: bool = True,
632
- generator=None,
633
- ) -> Union[SchedulerOutput, Tuple]:
634
- """
635
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
636
- the multistep UniPC.
637
-
638
- Args:
639
- model_output (`torch.Tensor`):
640
- The direct output from learned diffusion model.
641
- timestep (`int`):
642
- The current discrete timestep in the diffusion chain.
643
- sample (`torch.Tensor`):
644
- A current instance of a sample created by the diffusion process.
645
- return_dict (`bool`):
646
- Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
647
-
648
- Returns:
649
- [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
650
- If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
651
- tuple is returned where the first element is the sample tensor.
652
-
653
- """
654
- if self.num_inference_steps is None:
655
- raise ValueError(
656
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
657
- )
658
-
659
- if self.step_index is None:
660
- self._init_step_index(timestep)
661
-
662
- use_corrector = (
663
- self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None # pyright: ignore
664
- )
665
-
666
- model_output_convert = self.convert_model_output(model_output, sample=sample)
667
- if use_corrector:
668
- sample = self.multistep_uni_c_bh_update(
669
- this_model_output=model_output_convert,
670
- last_sample=self.last_sample,
671
- this_sample=sample,
672
- order=self.this_order,
673
- )
674
-
675
- for i in range(self.config.solver_order - 1):
676
- self.model_outputs[i] = self.model_outputs[i + 1]
677
- self.timestep_list[i] = self.timestep_list[i + 1]
678
-
679
- self.model_outputs[-1] = model_output_convert
680
- self.timestep_list[-1] = timestep # pyright: ignore
681
-
682
- if self.config.lower_order_final:
683
- this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore
684
- else:
685
- this_order = self.config.solver_order
686
-
687
- self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
688
- assert self.this_order > 0
689
-
690
- self.last_sample = sample
691
- prev_sample = self.multistep_uni_p_bh_update(
692
- model_output=model_output, # pass the original non-converted model output, in case solver-p is used
693
- sample=sample,
694
- order=self.this_order,
695
- )
696
-
697
- if self.lower_order_nums < self.config.solver_order:
698
- self.lower_order_nums += 1
699
-
700
- # upon completion increase step index by one
701
- self._step_index += 1 # pyright: ignore
702
-
703
- if not return_dict:
704
- return (prev_sample,)
705
-
706
- return SchedulerOutput(prev_sample=prev_sample)
707
-
708
- def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
709
- """
710
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
711
- current timestep.
712
-
713
- Args:
714
- sample (`torch.Tensor`):
715
- The input sample.
716
-
717
- Returns:
718
- `torch.Tensor`:
719
- A scaled input sample.
720
- """
721
- return sample
722
-
723
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
724
- def add_noise(
725
- self,
726
- original_samples: torch.Tensor,
727
- noise: torch.Tensor,
728
- timesteps: torch.IntTensor,
729
- ) -> torch.Tensor:
730
- # Make sure sigmas and timesteps have the same device and dtype as original_samples
731
- sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
732
- if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
733
- # mps does not support float64
734
- schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
735
- timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
736
- else:
737
- schedule_timesteps = self.timesteps.to(original_samples.device)
738
- timesteps = timesteps.to(original_samples.device)
739
-
740
- # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
741
- if self.begin_index is None:
742
- step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
743
- elif self.step_index is not None:
744
- # add_noise is called after first denoising step (for inpainting)
745
- step_indices = [self.step_index] * timesteps.shape[0]
746
- else:
747
- # add noise is called before first denoising step to create initial latent(img2img)
748
- step_indices = [self.begin_index] * timesteps.shape[0]
749
-
750
- sigma = sigmas[step_indices].flatten()
751
- while len(sigma.shape) < len(original_samples.shape):
752
- sigma = sigma.unsqueeze(-1)
753
-
754
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
755
- noisy_samples = alpha_t * original_samples + sigma_t * noise
756
- return noisy_samples
757
-
758
- def __len__(self):
759
- return self.config.num_train_timesteps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
assets/astronaut.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c13fd69a91c40ce217162e1bee917c23b86fcd878301bcab11a48fcd3bfeded
3
+ size 1020141
assets/bird.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ceddcad73d270f8a03a5bfa5ab6cc3dc74c9b1ff3db3a652151b5c26efd36e9c
3
+ size 757207
assets/fisherman.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e67764b0104e66ec56dab3cc216f3887aef84d25f3bb7758443de1b8624c745
3
+ size 1506256
assets/i2v_sample.jpg ADDED

Git LFS Details

  • SHA256: a3709a6989fc201b9c4332eadc30d34b365a41ef300f6bf5ecc1fdc40a7c8969
  • Pointer size: 131 Bytes
  • Size of remote file: 378 kB
assets/sage_compare_BF16.webp ADDED

Git LFS Details

  • SHA256: c6de38ff09e335e7c33e7de359418b27de30af05b63d08f0d9ec521bfb7a583f
  • Pointer size: 132 Bytes
  • Size of remote file: 6.53 MB
assets/sage_compare_Q4_K_M.webp ADDED

Git LFS Details

  • SHA256: 6e7625ab6be438419a421f35f297963c80e1314e9cffdbbfe2fe9438966046cf
  • Pointer size: 132 Bytes
  • Size of remote file: 6.03 MB
assets/sage_compare_Q5_K_M.webp ADDED

Git LFS Details

  • SHA256: d4698ccf9716113f605de3a4d2e5ccff24dddfb9b34279a0a05da416a9e701d6
  • Pointer size: 132 Bytes
  • Size of remote file: 6.83 MB
assets/sage_compare_Q8_0.webp ADDED

Git LFS Details

  • SHA256: e4d15be42c5aa4cff9398a6b24922d7473a9815210651072f8ea4059fe288d7f
  • Pointer size: 132 Bytes
  • Size of remote file: 6.4 MB
assets/underwater.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:184b72b8672bca924bf8f0b568488daac507eafd8dd8ddcb15f9928549309550
3
+ size 2024562
assets/woman.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2aa2e93ccc2d333f26323a0944e4a3e9c0ee3064824ae23b245f5ab2c548a947
3
+ size 2057707
docs/gguf-sageattention.md ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🧊 GGUF + SageAttention
2
+
3
+ > See the main [README](../README.md) for `FlowDPMSolver` and pipeline setup.
4
+
5
+ GGUF quantized transformer weights are available at [Motif-Video-2B-GGUF](https://huggingface.co/Motif-Technologies/Motif-Video-2B-GGUF), reducing VRAM with minimal quality loss. Combined with [SageAttention](https://github.com/thu-ml/SageAttention) for ~2× faster attention computation.
6
+
7
+ ## GGUF Inference
8
+
9
+ ```bash
10
+ pip install gguf
11
+ ```
12
+
13
+ ```python
14
+ import torch
15
+ from diffusers import (
16
+ AdaptiveProjectedGuidance,
17
+ DPMSolverMultistepScheduler,
18
+ GGUFQuantizationConfig,
19
+ MotifVideoPipeline,
20
+ MotifVideoTransformer3DModel,
21
+ )
22
+ from diffusers.utils import export_to_video
23
+ from huggingface_hub import hf_hub_download
24
+
25
+ guider = AdaptiveProjectedGuidance(
26
+ guidance_scale=8.0,
27
+ adaptive_projected_guidance_rescale=12.0,
28
+ adaptive_projected_guidance_momentum=0.1,
29
+ use_original_formulation=True,
30
+ normalization_dims="spatial",
31
+ )
32
+
33
+ variant = "Q4_K_M" # Options: Q4_0, Q4_1, Q4_K_M, Q5_0, Q5_1, Q5_K_M, Q6_K, Q8_0, BF16
34
+ ckpt_path = hf_hub_download(
35
+ "Motif-Technologies/Motif-Video-2B-GGUF",
36
+ filename=f"motifv-2b-dev-{variant}.gguf",
37
+ )
38
+
39
+ transformer = MotifVideoTransformer3DModel.from_single_file(
40
+ ckpt_path,
41
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
42
+ config="Motif-Technologies/Motif-Video-2B",
43
+ revision="diffusers-integration",
44
+ subfolder="transformer",
45
+ torch_dtype=torch.bfloat16,
46
+ )
47
+
48
+ pipe = MotifVideoPipeline.from_pretrained(
49
+ "Motif-Technologies/Motif-Video-2B",
50
+ revision="diffusers-integration",
51
+ torch_dtype=torch.bfloat16,
52
+ guider=guider,
53
+ transformer=transformer,
54
+ )
55
+
56
+ pipe.scheduler = FlowDPMSolver(
57
+ num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
58
+ algorithm_type="dpmsolver++",
59
+ solver_order=2,
60
+ prediction_type="flow_prediction",
61
+ use_flow_sigmas=True,
62
+ flow_shift=15.0,
63
+ )
64
+ pipe.enable_model_cpu_offload()
65
+
66
+ output = pipe(
67
+ prompt="A woman standing in a sunlit field as flower petals swirl around her in slow motion. Each petal floats gently through the golden light, casting tiny shadows. Her hair moves like water, and time seems to stand still.",
68
+ negative_prompt="text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
69
+ height=736,
70
+ width=1280,
71
+ num_frames=121,
72
+ num_inference_steps=50,
73
+ frame_rate=24,
74
+ use_linear_quadratic_schedule=False,
75
+ )
76
+ export_to_video(output.frames[0], "output.mp4", fps=24)
77
+ ```
78
+
79
+ ## SageAttention (Optional, ~1.6× faster)
80
+
81
+ Same prompt and seed, 1280x736, 121 frames, 50 steps. Left = SDPA, Right = SageAttention.
82
+
83
+ ![BF16](../assets/sage_compare_BF16.webp)
84
+ ![Q8_0](../assets/sage_compare_Q8_0.webp)
85
+ ![Q5_K_M](../assets/sage_compare_Q5_K_M.webp)
86
+ ![Q4_K_M](../assets/sage_compare_Q4_K_M.webp)
87
+
88
+ [SageAttention](https://github.com/thu-ml/SageAttention) accelerates attention by quantizing Q/K to INT8 and V to FP8, reducing memory bandwidth. Works with all GGUF variants.
89
+
90
+ **Install** (build from source — PyPI only has 1.x, need 2.x):
91
+
92
+ ```bash
93
+ # Set TORCH_CUDA_ARCH_LIST to match your GPU: "8.0" for A100, "9.0" for H100/H200
94
+ TORCH_CUDA_ARCH_LIST="9.0" pip install git+https://github.com/thu-ml/SageAttention.git --no-build-isolation
95
+ ```
96
+
97
+ **Usage with `inference.py`:**
98
+
99
+ ```bash
100
+ python inference.py --use-sage-attention --prompt "..."
101
+ ```
102
+
103
+ **Notes:**
104
+ - Requires NVIDIA GPU with SM70+
105
+ - SM90+ (H100, H200) — FP8 kernels for maximum speedup
106
+ - SM80-SM89 (A100, RTX 3090, RTX 4090) — FP16 kernels (still faster than SDPA)
107
+ - SM70-SM75 (V100, RTX 2080 Ti) — FP16 kernels
108
+ - Set `TORCH_CUDA_ARCH_LIST` to match your GPU when building (e.g., `"8.6"` for RTX 3090, `"8.9"` for RTX 4090)
109
+ - No quality degradation observed across all GGUF variants
110
+
111
+ ## Benchmark
112
+
113
+ Measured on NVIDIA H200, 1280x736, 121 frames, 50 steps, DPMSolver++ (order=2, flow_shift=15.0):
114
+
115
+ | Variant | SDPA (s/it) | Sage (s/it) | Speedup | Peak alloc (GB) | Peak rsv (GB) | Total SDPA (s) | Total Sage (s) |
116
+ |---------|------------|------------|---------|-----------------|----------------|----------------|----------------|
117
+ | BF16 | 23.36 | 14.75 | 1.58x | 14.78 / 15.12 | 24.93 / 24.90 | 1184 | 754 |
118
+ | Q8_0 | 23.16 | 14.49 | 1.60x | 13.10 / 13.44 | 23.14 / 23.11 | 1178 | 744 |
119
+ | Q6_K | 23.21 | 14.55 | 1.60x | 12.62 / 12.95 | 22.72 / 22.69 | 1178 | 747 |
120
+ | Q5_K_M | 23.33 | 14.69 | 1.59x | 12.39 / 12.72 | 22.45 / 22.42 | 1184 | 754 |
121
+ | Q5_1 | 23.54 | 14.96 | 1.57x | 12.47 / 12.81 | 22.66 / 22.62 | 1193 | 764 |
122
+ | Q5_0 | 23.26 | 14.67 | 1.59x | 12.37 / 12.71 | 22.55 / 22.52 | 1179 | 750 |
123
+ | Q4_K_M | 23.25 | 14.59 | 1.60x | 12.19 / 12.53 | 22.22 / 22.18 | 1178 | 747 |
124
+ | Q4_1 | 23.31 | 14.68 | 1.59x | 12.26 / 12.60 | 22.26 / 22.22 | 1181 | 750 |
125
+ | Q4_0 | 23.33 | 14.75 | 1.58x | 12.14 / 12.47 | 22.18 / 22.14 | 1188 | 760 |
126
+
127
+ Peak alloc/rsv columns show SDPA / Sage values. Sage adds ~0.3 GB alloc overhead (INT8/FP8 quantization buffers) with no change in reserved memory.
128
+
129
+ **Key findings:**
130
+ - **~1.59x faster with SageAttention** — consistent across all quantization levels
131
+ - **VRAM unchanged** — sage overhead is negligible (~0.3 GB alloc)
132
+ - **GGUF + Sage stacks** — Q4_K_M + Sage achieves 14.59 s/it at 12.53 GB alloc (vs BF16 SDPA: 23.36 s/it at 14.78 GB)
docs/memory-efficient-inference.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Memory-efficient Inference
2
+
3
+ > See the main [README](../README.md) for `FlowDPMSolver` and `guider` setup.
4
+
5
+ By default, `pipe.to("cuda")` loads all components onto the GPU simultaneously, requiring **~30 GB VRAM**.
6
+
7
+ For GPUs with 24 GB or less (e.g. RTX 4090, RTX 3090), use `enable_model_cpu_offload()` with the `expandable_segments` allocator setting:
8
+
9
+ ```bash
10
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
11
+ ```
12
+
13
+ ```python
14
+ pipe = MotifVideoPipeline.from_pretrained(
15
+ "Motif-Technologies/Motif-Video-2B",
16
+ revision="diffusers-integration",
17
+ torch_dtype=torch.bfloat16,
18
+ guider=guider, # see T2V example above
19
+ )
20
+ pipe.scheduler = FlowDPMSolver(
21
+ num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
22
+ algorithm_type="dpmsolver++",
23
+ solver_order=2,
24
+ prediction_type="flow_prediction",
25
+ use_flow_sigmas=True,
26
+ flow_shift=15.0,
27
+ )
28
+ pipe.enable_model_cpu_offload() # replaces pipe.to("cuda")
29
+
30
+ output = pipe(
31
+ prompt="...",
32
+ negative_prompt="...",
33
+ height=736, width=1280, num_frames=121, num_inference_steps=50,
34
+ frame_rate=24, use_linear_quadratic_schedule=False,
35
+ )
36
+ export_to_video(output.frames[0], "output.mp4", fps=24)
37
+ ```
38
+
39
+ This moves each component (text encoder → transformer → VAE) to GPU only when needed. The `expandable_segments` setting allows the CUDA memory allocator to efficiently reuse memory released by earlier components, avoiding fragmentation-related OOM errors.
40
+
41
+ | Mode | Peak VRAM | Speed | Recommended GPU |
42
+ |------|-----------|-------|-----------------|
43
+ | `pipe.to("cuda")` | ~30 GB | Fastest | A100, H100, H200 |
44
+ | `enable_model_cpu_offload()` | ~19 GB | Similar | RTX 4090, RTX 3090 |
45
+
46
+ ## FP8 Weight Quantization (Optional)
47
+
48
+ For further VRAM reduction, you can quantize the transformer weights to FP8 using [torchao](https://github.com/pytorch/ao):
49
+
50
+ ```bash
51
+ pip install torchao
52
+ ```
53
+
54
+ ```python
55
+ from torchao.quantization import quantize_, Float8WeightOnlyConfig
56
+
57
+ pipe = MotifVideoPipeline.from_pretrained(
58
+ "Motif-Technologies/Motif-Video-2B",
59
+ revision="diffusers-integration",
60
+ torch_dtype=torch.bfloat16,
61
+ guider=guider, # see T2V example above
62
+ )
63
+ pipe.scheduler = FlowDPMSolver(
64
+ num_train_timesteps=pipe.scheduler.config.get("num_train_timesteps", 1000),
65
+ algorithm_type="dpmsolver++",
66
+ solver_order=2,
67
+ prediction_type="flow_prediction",
68
+ use_flow_sigmas=True,
69
+ flow_shift=15.0,
70
+ )
71
+ quantize_(pipe.transformer, Float8WeightOnlyConfig())
72
+ pipe.enable_model_cpu_offload()
73
+
74
+ output = pipe(
75
+ prompt="...",
76
+ negative_prompt="...",
77
+ height=736, width=1280, num_frames=121, num_inference_steps=50,
78
+ frame_rate=24, use_linear_quadratic_schedule=False,
79
+ )
80
+ export_to_video(output.frames[0], "output.mp4", fps=24)
81
+ ```
82
+
83
+ This stores the transformer weights in FP8 (8-bit) instead of BF16 (16-bit), reducing peak VRAM from ~19 GB to ~15 GB while keeping all computation in BF16 precision.
84
+
85
+ | Mode | Peak VRAM | Notes |
86
+ |------|-----------|-------|
87
+ | `enable_model_cpu_offload()` | ~19 GB | BF16 baseline |
88
+ | `+ Float8WeightOnlyConfig` | ~15 GB | FP8 weights, BF16 compute |
inference.py DELETED
@@ -1,119 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Motif-Video 2B — Text-to-Video & Image-to-Video inference.
3
-
4
- GPU requirements: ~24GB VRAM for 720p (1280x736, 121 frames).
5
- Tested with: torch>=2.0, diffusers>=0.35.2, transformers>=5.0.0
6
-
7
- Uses Adaptive Projected Guidance (APG) by default for best quality.
8
- """
9
-
10
- import argparse
11
-
12
- import torch
13
- from diffusers import AdaptiveProjectedGuidance, DiffusionPipeline
14
- from diffusers.utils import export_to_video
15
-
16
-
17
- def parse_args():
18
- parser = argparse.ArgumentParser(description="Motif-Video 2B Inference (T2V / I2V)")
19
- parser.add_argument(
20
- "--model-path",
21
- type=str,
22
- default="Motif-Technologies/Motif-Video-2B",
23
- help="HuggingFace model ID or local checkpoint path (uses trust_remote_code=True)",
24
- )
25
- parser.add_argument(
26
- "--prompt",
27
- type=str,
28
- default="A category-five hurricane, viewed from inside the eye, reveals a circular stadium of cloud walls rising to fifty thousand feet with an eerie disk of blue sky directly overhead. Shot from a NOAA reconnaissance aircraft mounted camera, the perspective looks outward toward the eyewall — a near-vertical curtain of rotating cloud and lightning that is simultaneously terrifying and transcendent. The inner surface of the eyewall catches the setting sun, painting it in improbable shades of peach and rose. The camera slowly pans 360 degrees to complete one full revolution, capturing the entire coliseum of the storm. Below, the ocean surface is a white blur of foam and spray. The documentary-style cinematography strips away all artifice to present the storm as an entity of pure elemental power.",
29
- help="Text prompt for video generation",
30
- )
31
- parser.add_argument(
32
- "--image",
33
- type=str,
34
- default=None,
35
- help="Path to input image for I2V mode (omit for T2V)",
36
- )
37
- parser.add_argument(
38
- "--negative-prompt",
39
- type=str,
40
- default=None,
41
- help="Negative prompt (default: built-in pipeline default)",
42
- )
43
- parser.add_argument("--output", type=str, default="output.mp4", help="Output video file path")
44
- parser.add_argument("--num-frames", type=int, default=121, help="Number of frames to generate (121 = ~5s at 24fps)")
45
- parser.add_argument("--height", type=int, default=736, help="Video height in pixels")
46
- parser.add_argument("--width", type=int, default=1280, help="Video width in pixels")
47
- parser.add_argument("--guidance-scale", type=float, default=8.0, help="Classifier-free guidance scale")
48
- parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of denoising steps")
49
- parser.add_argument("--fps", type=int, default=24, help="Output video frame rate")
50
- parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
51
- parser.add_argument(
52
- "--dtype",
53
- type=str,
54
- default="bfloat16",
55
- choices=["float16", "bfloat16", "float32"],
56
- help="Model dtype",
57
- )
58
- return parser.parse_args()
59
-
60
-
61
- def main():
62
- args = parse_args()
63
-
64
- dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
65
- torch_dtype = dtype_map[args.dtype]
66
-
67
- mode = "I2V" if args.image else "T2V"
68
- print(f"[{mode}] Loading model from: {args.model_path}")
69
-
70
- guider = AdaptiveProjectedGuidance(
71
- guidance_scale=args.guidance_scale,
72
- adaptive_projected_guidance_rescale=12.0,
73
- adaptive_projected_guidance_momentum=0.1,
74
- eta=0.0,
75
- use_original_formulation=True,
76
- )
77
-
78
- pipe = DiffusionPipeline.from_pretrained(
79
- args.model_path,
80
- custom_pipeline="pipeline_motif_video",
81
- trust_remote_code=True,
82
- torch_dtype=torch_dtype,
83
- guider=guider,
84
- )
85
- pipe = pipe.to("cuda")
86
-
87
- generator = torch.Generator(device="cuda").manual_seed(args.seed)
88
-
89
- # Load image for I2V mode
90
- image = None
91
- if args.image:
92
- from PIL import Image
93
-
94
- image = Image.open(args.image).convert("RGB")
95
- print(f"[I2V] Input image: {args.image} ({image.size[0]}x{image.size[1]})")
96
-
97
- print(f"Generating video: {args.width}x{args.height}, {args.num_frames} frames, {args.num_inference_steps} steps")
98
- pipe_kwargs = dict(
99
- prompt=args.prompt,
100
- image=image,
101
- height=args.height,
102
- width=args.width,
103
- num_frames=args.num_frames,
104
- num_inference_steps=args.num_inference_steps,
105
- generator=generator,
106
- frame_rate=args.fps,
107
- )
108
- if args.negative_prompt is not None:
109
- pipe_kwargs["negative_prompt"] = args.negative_prompt
110
-
111
- output = pipe(**pipe_kwargs)
112
-
113
- video_frames = output.frames[0]
114
- export_to_video(video_frames, args.output, fps=args.fps)
115
- print(f"Video saved to: {args.output}")
116
-
117
-
118
- if __name__ == "__main__":
119
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_index.json CHANGED
@@ -14,7 +14,7 @@
14
  "GemmaTokenizer"
15
  ],
16
  "transformer": [
17
- "transformer_motif_video",
18
  "MotifVideoTransformer3DModel"
19
  ],
20
  "vae": [
 
14
  "GemmaTokenizer"
15
  ],
16
  "transformer": [
17
+ "diffusers",
18
  "MotifVideoTransformer3DModel"
19
  ],
20
  "vae": [
pipeline_motif_video.py DELETED
@@ -1,1388 +0,0 @@
1
- # Copyright 2026 Motif Technologies, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import html
16
- import inspect
17
- from dataclasses import dataclass
18
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
-
20
- import ftfy
21
- import numpy as np
22
- import regex as re
23
- import torch
24
- from diffusers import (
25
- AdaptiveProjectedGuidance,
26
- AutoencoderKLWan,
27
- ClassifierFreeGuidance,
28
- DiffusionPipeline,
29
- DPMSolverMultistepScheduler,
30
- FlowMatchEulerDiscreteScheduler,
31
- SkipLayerGuidance,
32
- UniPCMultistepScheduler,
33
- )
34
- from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
35
- from diffusers.guiders.adaptive_projected_guidance import MomentumBuffer
36
- from diffusers.guiders.guider_utils import GuiderOutput
37
- from diffusers.utils import (
38
- BaseOutput,
39
- is_torch_xla_available,
40
- logging,
41
- replace_example_docstring,
42
- )
43
- from diffusers.utils.torch_utils import randn_tensor
44
- from diffusers.video_processor import VideoProcessor
45
- from einops import rearrange
46
- from PIL import Image
47
- from torch import Tensor
48
-
49
- from transformers import (
50
- BatchEncoding,
51
- PreTrainedTokenizerBase,
52
- SiglipImageProcessor,
53
- T5Gemma2Encoder,
54
- )
55
-
56
- from ._fm_solvers_unipc import FlowUniPCMultistepScheduler
57
-
58
-
59
- if is_torch_xla_available():
60
- import torch_xla.core.xla_model as xm
61
-
62
- XLA_AVAILABLE = True
63
- else:
64
- XLA_AVAILABLE = False
65
-
66
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
67
-
68
- EXAMPLE_DOC_STRING = """
69
- Examples:
70
- ```py
71
- >>> import torch
72
- >>> from diffusers import MotifVideoPipeline
73
- >>> from diffusers.utils import export_to_video
74
-
75
- >>> # Load the Motif Video pipeline
76
- >>> motif_video_model_id = "MotifTechnologies/Motif-Video"
77
- >>> pipe = MotifVideoPipeline.from_pretrained(motif_video_model_id, torch_dtype=torch.bfloat16)
78
- >>> pipe.to("cuda")
79
-
80
- >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
81
- >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
82
-
83
- >>> video = pipe(
84
- ... prompt=prompt,
85
- ... negative_prompt=negative_prompt,
86
- ... width=640,
87
- ... height=352,
88
- ... num_frames=65,
89
- ... num_inference_steps=50,
90
- ... ).frames[0]
91
- >>> export_to_video(video, "output.mp4", fps=16)
92
- ```
93
- """
94
-
95
-
96
- @dataclass
97
- class MotifVideoPipelineOutput(BaseOutput):
98
- r"""
99
- Output class for Motif Video pipelines.
100
-
101
- Args:
102
- frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
103
- List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
104
- denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
105
- `(batch_size, num_frames, channels, height, width)`.
106
- """
107
-
108
- frames: torch.Tensor
109
-
110
-
111
- """Video-aware Adaptive Projected Guidance (APG).
112
-
113
- Standard APG normalizes over all spatial dimensions [C, T, H, W], which collapses
114
- temporal variation. This module normalizes over [C, H, W] only, preserving
115
- per-frame independence.
116
- """
117
-
118
-
119
- def video_normalized_guidance(
120
- pred_cond: torch.Tensor,
121
- pred_uncond: torch.Tensor,
122
- guidance_scale: float,
123
- momentum_buffer: MomentumBuffer | None = None,
124
- eta: float = 1.0,
125
- norm_threshold: float = 0.0,
126
- use_original_formulation: bool = False,
127
- ) -> torch.Tensor:
128
- """APG with video-aware normalization: normalize over [C, H, W], exclude T.
129
-
130
- For 5D input [B, C, T, H, W], dim=[-1, -2, -4] normalizes per-frame (W, H, C),
131
- keeping the T dimension independent. For 4D input [B, C, H, W], falls back to
132
- standard [-1, -2, -3] behavior.
133
- """
134
- diff = pred_cond - pred_uncond
135
-
136
- if len(diff.shape) == 5:
137
- # [B, C, T, H, W] → normalize over W(-1), H(-2), C(-4), skip T(-3)
138
- dim = [-1, -2, -4]
139
- else:
140
- # [B, C, H, W] → standard behavior
141
- dim = [-i for i in range(1, len(diff.shape))]
142
-
143
- if momentum_buffer is not None:
144
- momentum_buffer.update(diff)
145
- diff = momentum_buffer.running_average
146
-
147
- if norm_threshold > 0:
148
- ones = torch.ones_like(diff)
149
- diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
150
- scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
151
- diff = diff * scale_factor
152
-
153
- v0, v1 = diff.double(), pred_cond.double()
154
- v1 = torch.nn.functional.normalize(v1, dim=dim)
155
- v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
156
- v0_orthogonal = v0 - v0_parallel
157
- diff_parallel, diff_orthogonal = (
158
- v0_parallel.type_as(diff),
159
- v0_orthogonal.type_as(diff),
160
- )
161
- normalized_update = diff_orthogonal + eta * diff_parallel
162
-
163
- pred = pred_cond if use_original_formulation else pred_uncond
164
- pred = pred + guidance_scale * normalized_update
165
-
166
- return pred
167
-
168
-
169
- class VideoAdaptiveProjectedGuidance(AdaptiveProjectedGuidance):
170
- """APG variant that normalizes over [C, H, W] per frame, excluding the T dimension."""
171
-
172
- def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
173
- pred = None
174
-
175
- if not self._is_apg_enabled():
176
- pred = pred_cond
177
- else:
178
- pred = video_normalized_guidance(
179
- pred_cond,
180
- pred_uncond,
181
- self.guidance_scale,
182
- self.momentum_buffer,
183
- self.eta,
184
- self.adaptive_projected_guidance_rescale,
185
- self.use_original_formulation,
186
- )
187
-
188
- if self.guidance_rescale > 0.0:
189
- from diffusers.guiders.classifier_free_guidance import rescale_noise_cfg
190
-
191
- pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
192
-
193
- return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
194
-
195
-
196
- # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
197
- def calculate_shift(
198
- image_seq_len,
199
- base_seq_len: int = 256,
200
- max_seq_len: int = 4096,
201
- base_shift: float = 0.5,
202
- max_shift: float = 1.15,
203
- ):
204
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
205
- b = base_shift - m * base_seq_len
206
- mu = image_seq_len * m + b
207
- return mu
208
-
209
-
210
- def get_linear_quadratic_sigmas(
211
- num_inference_steps: int,
212
- linear_quadratic_emulating_steps: int = 250,
213
- ) -> np.ndarray:
214
- """
215
- Compute a linear-quadratic sigma schedule for flow matching.
216
-
217
- This schedule combines:
218
- - First half: Linear interpolation from high noise to medium noise (slow denoising)
219
- - Second half: Quadratic interpolation from medium noise to clean (faster denoising)
220
-
221
- Convention:
222
- - sigma=1.0 represents pure noise
223
- - sigma=0.0 represents clean image
224
- - Output sigmas are in descending order (1.0 → ~0)
225
-
226
- Args:
227
- num_inference_steps: Total number of denoising steps (must be even).
228
- linear_quadratic_emulating_steps: Controls the slope of linear interpolation.
229
- Higher values result in gentler slope in the first half.
230
-
231
- Returns:
232
- np.ndarray: Array of sigma values with shape (num_inference_steps,).
233
- The scheduler will append a terminal 0.
234
-
235
- Raises:
236
- ValueError: If num_inference_steps is not even.
237
-
238
- Reference:
239
- Linear-quadratic timestep schedule for improved flow matching inference.
240
- """
241
- if num_inference_steps % 2 != 0:
242
- raise ValueError(
243
- f"num_inference_steps must be even for linear-quadratic schedule, but got {num_inference_steps}"
244
- )
245
-
246
- steps = num_inference_steps
247
- N = linear_quadratic_emulating_steps
248
- half_steps = steps // 2
249
-
250
- # First half: linear interpolation from 1 toward 0
251
- # Takes first half_steps values from linspace(1, 0, N+1)
252
- linear_part = np.linspace(1.0, 0.0, N + 1)[:half_steps]
253
-
254
- # Second half: quadratic interpolation
255
- # Formula: x^2 * (half_steps/N - 1) - (half_steps/N - 1)
256
- # = (half_steps/N - 1) * (x^2 - 1)
257
- # This maps x=0 to (half_steps/N - 1) * (-1) = 1 - half_steps/N
258
- # and maps x=1 to 0
259
- x = np.linspace(0.0, 1.0, half_steps + 1)
260
- scale_factor = half_steps / N - 1 # negative value
261
- quadratic_part = x**2 * scale_factor - scale_factor
262
-
263
- # Concatenate and exclude the last 0 (scheduler appends terminal 0)
264
- sigmas = np.concatenate([linear_part, quadratic_part])
265
- sigmas = sigmas[:-1] # Remove trailing 0, scheduler will append it
266
-
267
- return sigmas.astype(np.float32)
268
-
269
-
270
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
271
- def retrieve_timesteps(
272
- scheduler,
273
- num_inference_steps: Optional[int] = None,
274
- device: Optional[Union[str, torch.device]] = None,
275
- timesteps: Optional[List[int]] = None,
276
- sigmas: Optional[List[float]] = None,
277
- use_linear_quadratic_schedule: bool = False,
278
- linear_quadratic_emulating_steps: int = 250,
279
- **kwargs,
280
- ):
281
- """
282
- Retrieve timesteps from the scheduler.
283
-
284
- Args:
285
- scheduler: The noise scheduler to use.
286
- num_inference_steps: Number of denoising steps.
287
- device: Device to place timesteps on.
288
- timesteps: Custom timestep values (mutually exclusive with sigmas).
289
- sigmas: Custom sigma values (mutually exclusive with timesteps).
290
- use_linear_quadratic_schedule: If True, use linear-quadratic sigma schedule.
291
- This overrides the default linear schedule. Requires num_inference_steps
292
- to be even.
293
- linear_quadratic_emulating_steps: Controls the linear portion slope.
294
- Higher values result in gentler slope in the first half. Default: 250.
295
- **kwargs: Additional arguments passed to scheduler.set_timesteps().
296
-
297
- Returns:
298
- Tuple of (timesteps, num_inference_steps).
299
-
300
- Raises:
301
- ValueError: If both timesteps and sigmas are provided, or if
302
- use_linear_quadratic_schedule is True but num_inference_steps is odd.
303
- """
304
- if timesteps is not None and sigmas is not None:
305
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
306
-
307
- # Handle linear-quadratic schedule: compute sigmas if flag is set
308
- if use_linear_quadratic_schedule:
309
- if sigmas is not None:
310
- raise ValueError(
311
- "Cannot use both `sigmas` and `use_linear_quadratic_schedule`. "
312
- "The linear-quadratic schedule computes sigmas automatically."
313
- )
314
- if num_inference_steps is None:
315
- raise ValueError("`num_inference_steps` must be provided when using `use_linear_quadratic_schedule`.")
316
- sigmas = get_linear_quadratic_sigmas(
317
- num_inference_steps=num_inference_steps,
318
- linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
319
- )
320
-
321
- if timesteps is not None:
322
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
323
- if not accepts_timesteps:
324
- raise ValueError(
325
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
326
- f" timestep schedules. Please check whether you are using the correct scheduler."
327
- )
328
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
329
- timesteps = scheduler.timesteps
330
- num_inference_steps = len(timesteps)
331
- elif sigmas is not None:
332
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
333
- if not accept_sigmas:
334
- raise ValueError(
335
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
336
- f" sigmas schedules. Please check whether you are using the correct scheduler."
337
- )
338
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
339
- timesteps = scheduler.timesteps
340
- num_inference_steps = len(timesteps)
341
- else:
342
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
343
- timesteps = scheduler.timesteps
344
- return timesteps, num_inference_steps
345
-
346
-
347
- def basic_clean(text):
348
- text = ftfy.fix_text(text)
349
- text = html.unescape(html.unescape(text))
350
- return text.strip()
351
-
352
-
353
- def whitespace_clean(text):
354
- text = re.sub(r"\s+", " ", text)
355
- text = text.strip()
356
- return text
357
-
358
-
359
- def prompt_clean(text):
360
- text = whitespace_clean(basic_clean(text))
361
- return text
362
-
363
-
364
- class MotifVideoPipeline(DiffusionPipeline):
365
- r"""
366
- Pipeline for text-to-video generation using MotifVideoTransformer.
367
-
368
- Args:
369
- transformer ([`MotifVideoTransformer3DModel`]):
370
- Conditional Transformer architecture to denoise the encoded video latents.
371
- scheduler ([`FlowMatchEulerDiscreteScheduler`]):
372
- A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
373
- vae ([`AutoencoderKLWan`]):
374
- Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
375
- text_encoder ([`T5Gemma2Encoder`]):
376
- Primary text encoder for encoding text prompts into embeddings.
377
- tokenizer ([`PreTrainedTokenizerBase`]):
378
- Tokenizer corresponding to the primary text encoder.
379
- guider ([`ClassifierFreeGuidance`] or [`SkipLayerGuidance`] or [`AdaptiveProjectedGuidance`] or [`VideoAdaptiveProjectedGuidance`], *optional*):
380
- The guidance method to use. If `None`, it defaults to `ClassifierFreeGuidance()`.
381
- """
382
-
383
- model_cpu_offload_seq = "text_encoder->transformer->vae"
384
- _optional_components = ["feature_extractor"]
385
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
386
-
387
- def __init__(
388
- self,
389
- scheduler: Union[
390
- FlowMatchEulerDiscreteScheduler,
391
- DPMSolverMultistepScheduler,
392
- UniPCMultistepScheduler,
393
- FlowUniPCMultistepScheduler,
394
- ],
395
- vae: AutoencoderKLWan,
396
- text_encoder: T5Gemma2Encoder,
397
- tokenizer: PreTrainedTokenizerBase,
398
- transformer,
399
- guider: Optional[
400
- Union[
401
- ClassifierFreeGuidance,
402
- SkipLayerGuidance,
403
- AdaptiveProjectedGuidance,
404
- VideoAdaptiveProjectedGuidance,
405
- ]
406
- ] = None,
407
- feature_extractor: Optional[SiglipImageProcessor] = None,
408
- ):
409
- super().__init__()
410
-
411
- self.guider = ClassifierFreeGuidance() if guider is None else guider
412
-
413
- self.register_modules(
414
- vae=vae,
415
- text_encoder=text_encoder,
416
- tokenizer=tokenizer,
417
- transformer=transformer,
418
- scheduler=scheduler,
419
- feature_extractor=feature_extractor,
420
- )
421
-
422
- self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
423
- self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
424
-
425
- self.transformer_spatial_patch_size = (
426
- self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 2
427
- )
428
- self.transformer_temporal_patch_size = (
429
- self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
430
- )
431
-
432
- self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
433
- self.tokenizer_max_length = (
434
- self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 512
435
- )
436
-
437
- def _get_default_embeds(
438
- self,
439
- text_encoder,
440
- tokenizer: PreTrainedTokenizerBase,
441
- prompt: Union[str, List[str]],
442
- max_sequence_length: int = 512,
443
- device: Optional[torch.device] = None,
444
- dtype: Optional[torch.dtype] = None,
445
- ) -> Tuple[torch.Tensor, torch.Tensor]:
446
- dtype = dtype or text_encoder.dtype
447
-
448
- text_inputs = tokenizer(
449
- prompt,
450
- padding="max_length",
451
- max_length=max_sequence_length,
452
- truncation=True,
453
- add_special_tokens=True,
454
- return_attention_mask=True,
455
- return_tensors="pt",
456
- )
457
- text_inputs = BatchEncoding(
458
- {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in text_inputs.items()}
459
- )
460
-
461
- prompt_embeds = text_encoder(**text_inputs)[0]
462
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
463
-
464
- return prompt_embeds, text_inputs.attention_mask
465
-
466
- def _average_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
467
- last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
468
- denom = attention_mask.sum(dim=1, keepdim=True).clamp(min=1) # avoid div by zero
469
- return last_hidden.sum(dim=1) / denom
470
-
471
- def _get_prompt_embeds(
472
- self,
473
- text_encoder: T5Gemma2Encoder,
474
- tokenizer: PreTrainedTokenizerBase,
475
- prompt: Union[str, List[str]] | None = None,
476
- num_videos_per_prompt: int = 1,
477
- max_sequence_length: int = 512,
478
- device: Optional[torch.device] = None,
479
- dtype: Optional[torch.dtype] = None,
480
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
481
- device = device or self._execution_device
482
-
483
- prompt = [prompt] if isinstance(prompt, str) else prompt
484
-
485
- prompt_embeds_kwargs = {
486
- "text_encoder": text_encoder,
487
- "tokenizer": tokenizer,
488
- "prompt": prompt,
489
- "max_sequence_length": max_sequence_length,
490
- "device": device,
491
- "dtype": dtype,
492
- }
493
- # When enable_model_cpu_offload() is active, the accelerate forward hook is on text_encoder (parent). Moving the encoder to the execution device explicitly ensures inputs and
494
- # weights are on the same device. The parent's offload hook will move text_encoder back to CPU after
495
- # the next component claims the GPU.
496
- if next(text_encoder.parameters()).device != torch.device(device):
497
- text_encoder.to(device)
498
- prompt_embeds, prompt_attention_mask = self._get_default_embeds(**prompt_embeds_kwargs)
499
-
500
- pooled_prompt_embeds = self._average_pool(prompt_embeds, prompt_attention_mask)
501
-
502
- return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
503
-
504
- # Keep encode_prompt structure, uses _get_prompt_embeds internally
505
- def encode_prompt(
506
- self,
507
- prompt: Union[str, List[str]],
508
- num_videos_per_prompt: int = 1,
509
- prompt_embeds: Optional[torch.Tensor] = None,
510
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
511
- prompt_attention_mask: Optional[torch.Tensor] = None,
512
- max_sequence_length: int = 512,
513
- device: Optional[torch.device] = None,
514
- dtype: Optional[torch.dtype] = None,
515
- ) -> Tuple[
516
- torch.Tensor,
517
- torch.Tensor,
518
- torch.Tensor,
519
- ]:
520
- device = device or self._execution_device
521
-
522
- prompt = [prompt] if isinstance(prompt, str) else prompt
523
- if prompt is not None:
524
- batch_size = len(prompt)
525
- else:
526
- batch_size = prompt_embeds.shape[0]
527
-
528
- prompt_embeds_kwargs = {
529
- "device": device,
530
- "dtype": dtype,
531
- }
532
-
533
- if prompt_embeds is None:
534
- prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self._get_prompt_embeds(
535
- text_encoder=self.text_encoder,
536
- tokenizer=self.tokenizer,
537
- prompt=prompt,
538
- max_sequence_length=max_sequence_length,
539
- **prompt_embeds_kwargs,
540
- )
541
-
542
- # Compute actual (non-padding) token count for batch=1 Flash Attention trimming in __call__
543
- actual_seq_len = None
544
- if batch_size == 1 and prompt_attention_mask is not None:
545
- actual_seq_len = int(prompt_attention_mask.sum(dim=-1).max().item())
546
-
547
- # duplicate text embeddings for each generation per prompt, using mps friendly method
548
- seq_len = prompt_embeds.shape[1]
549
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
550
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
551
-
552
- if pooled_prompt_embeds is not None:
553
- pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
554
-
555
- if prompt_attention_mask is not None:
556
- prompt_attention_mask = prompt_attention_mask.bool()
557
- prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
558
- prompt_attention_mask = prompt_attention_mask.repeat_interleave(num_videos_per_prompt, dim=0)
559
-
560
- return (
561
- prompt_embeds,
562
- pooled_prompt_embeds,
563
- prompt_attention_mask,
564
- actual_seq_len,
565
- )
566
-
567
- @property
568
- def vision_encoder(self):
569
- """Get the vision encoder from T5Gemma2.
570
-
571
- T5Gemma2 has vision_tower.vision_model structure.
572
- Will raise AttributeError if not available.
573
- """
574
- return self.text_encoder.vision_tower.vision_model
575
-
576
- def encode_image(
577
- self,
578
- image: Image.Image,
579
- batch_size: int = 1,
580
- device: Optional[torch.device] = None,
581
- dtype: Optional[torch.dtype] = None,
582
- ) -> torch.Tensor:
583
- """Encode image to embeddings using SigLIP vision encoder."""
584
- device = device or self._execution_device
585
- dtype = dtype or self.transformer.dtype
586
-
587
- image_embeds = self._get_image_embeds(
588
- image_encoder=self.vision_encoder,
589
- feature_extractor=self.feature_extractor,
590
- image=image,
591
- device=device,
592
- )
593
- image_embeds = image_embeds.repeat(batch_size, 1, 1)
594
- return image_embeds.to(device=device, dtype=dtype)
595
-
596
- @staticmethod
597
- def _get_image_embeds(
598
- image_encoder,
599
- feature_extractor: SiglipImageProcessor,
600
- image,
601
- device: torch.device,
602
- ) -> torch.Tensor:
603
- """Helper to encode single image with SigLIP.
604
-
605
- Args:
606
- image_encoder: The SigLIP vision encoder model.
607
- feature_extractor: SiglipImageProcessor for preprocessing.
608
- image: Can be either:
609
- - PIL.Image.Image: Will be preprocessed by feature_extractor
610
- - torch.Tensor: Assumed to be in [0, 1] range, will be normalized and passed to encoder
611
- device: Device to place tensors on.
612
-
613
- Returns:
614
- Image embeddings from the vision encoder.
615
- """
616
- image_encoder_dtype = next(image_encoder.parameters()).dtype
617
-
618
- if isinstance(image, torch.Tensor):
619
- image = feature_extractor.preprocess(
620
- images=image.float(),
621
- do_resize=True,
622
- do_rescale=False,
623
- do_normalize=True,
624
- do_convert_rgb=True,
625
- return_tensors="pt",
626
- )
627
- else:
628
- image = feature_extractor.preprocess(
629
- images=image,
630
- do_resize=True,
631
- do_rescale=False,
632
- do_normalize=True,
633
- do_convert_rgb=True,
634
- return_tensors="pt",
635
- )
636
-
637
- image = image.to(device, dtype=image_encoder_dtype)
638
- return image_encoder(**image).last_hidden_state
639
-
640
- @torch.compiler.disable
641
- def _prepare_first_frame_conditioning(
642
- self,
643
- video: torch.Tensor,
644
- latents: torch.Tensor,
645
- use_conditioning: bool,
646
- generator: Optional[torch.Generator] = None,
647
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
648
- """Prepare first frame conditioning tensors.
649
-
650
- This method implements batch-level conditioning where entire
651
- batches are either I2V (all samples conditioned) or T2V (no conditioning). This
652
- prevents mode confusion within batches.
653
-
654
- For I2V mode:
655
- 1. Extract and VAE-encode first frame from video
656
- 2. Create latent_condition by repeating first frame across time (frame 0 only)
657
- 3. Create latent_mask with 1.0 at frame 0
658
- 4. Get image_embeds from vision encoder
659
-
660
- For T2V mode:
661
- 1. Pad with zeros for latent_condition and latent_mask
662
-
663
- Args:
664
- video: Input video tensor [batch_size, frames, channels, height, width] in [-1, 1]
665
- latents: Latents [batch_size, lantent_channels, latent_num_frames, latent_height, latent_width]
666
- use_conditioning: Whether to use first-frame conditioning (True for I2V, False for T2V)
667
- generator: Optional random number generator for reproducibility
668
-
669
- Returns:
670
- Tuple of (latent_condition, latent_mask, image_embeds).
671
- - latent_condition: [B, C, F, H, W] conditioning signal (zeros for T2V)
672
- - latent_mask: [B, 1, F, H, W] binary mask (zeros for T2V)
673
- - image_embeds: [B, N, D] image embeddings from vision encoder or None for T2V
674
- """
675
- batch_size, lantent_channels, latent_num_frames, latent_height, latent_width = latents.shape
676
- device = latents.device
677
- dtype = latents.dtype
678
-
679
- # Determine if we should use conditioning
680
- use_conditioning = use_conditioning and (latent_num_frames > 1)
681
-
682
- # Initialize conditioning tensors
683
- latent_condition = torch.zeros(
684
- batch_size,
685
- lantent_channels,
686
- latent_num_frames,
687
- latent_height,
688
- latent_width,
689
- device=device,
690
- dtype=dtype,
691
- )
692
- latent_mask = torch.zeros(
693
- batch_size,
694
- 1,
695
- latent_num_frames,
696
- latent_height,
697
- latent_width,
698
- device=device,
699
- dtype=dtype,
700
- )
701
- image_embeds = None
702
-
703
- if use_conditioning:
704
- with torch.no_grad():
705
- # Encode first frame for latent_condition
706
- first_frame_latents = self.vae.encode(
707
- rearrange(video[:, 0:1], "b f c h w -> b c f h w")
708
- ).latent_dist.sample(generator=generator)
709
- first_frame_latents = self._normalize_latents(
710
- latents=first_frame_latents,
711
- latents_mean=self.vae.config.latents_mean,
712
- latents_std=self.vae.config.latents_std,
713
- )
714
-
715
- # Create latent_condition by repeating first frame across time
716
- latent_condition = first_frame_latents.repeat(1, 1, latent_num_frames, 1, 1)
717
- latent_condition[:, :, 1:, :, :] = 0
718
-
719
- # latent_mask: 1.0 at frame 0, 0.0 elsewhere
720
- latent_mask[:, :, 0] = 1.0
721
-
722
- # image_embeds from vision encoder
723
- first_frame_vision = video[:, 0] # [B, C, H, W]
724
- first_frame_vision = ((first_frame_vision + 1) / 2).clamp(0, 1)
725
-
726
- with torch.no_grad():
727
- image_embeds = self._get_image_embeds(
728
- image_encoder=self.vision_encoder,
729
- feature_extractor=self.feature_extractor,
730
- image=first_frame_vision,
731
- device=device,
732
- )
733
-
734
- return latent_condition, latent_mask, image_embeds
735
-
736
- def check_inputs(
737
- self,
738
- prompt,
739
- negative_prompt,
740
- height,
741
- width,
742
- batch_size,
743
- callback_on_step_end_tensor_inputs=None,
744
- prompt_embeds=None,
745
- negative_prompt_embeds=None,
746
- prompt_attention_mask=None,
747
- negative_prompt_attention_mask=None,
748
- ):
749
- # Resolution must be divisible by VAE scale factor * transformer patch size
750
- # (e.g. 8 * 2 = 16 for default config) to avoid latent/patch dimension mismatch.
751
- spatial_divisor = self.vae_scale_factor_spatial * self.transformer_spatial_patch_size
752
- if height % spatial_divisor != 0 or width % spatial_divisor != 0:
753
- raise ValueError(
754
- f"`height` and `width` have to be divisible by {spatial_divisor} "
755
- f"(vae_scale={self.vae_scale_factor_spatial} * patch_size={self.transformer_spatial_patch_size}) "
756
- f"but are {height} and {width}."
757
- )
758
-
759
- if callback_on_step_end_tensor_inputs is not None and not all(
760
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
761
- ):
762
- raise ValueError(
763
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
764
- )
765
-
766
- if prompt is not None and prompt_embeds is not None:
767
- raise ValueError(
768
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
769
- " only forward one of the two."
770
- )
771
- elif prompt is None and prompt_embeds is None:
772
- raise ValueError(
773
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
774
- )
775
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
776
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
777
-
778
- # Validate negative_prompt: must be None, str, or list with matching batch_size
779
- if negative_prompt is not None:
780
- if not isinstance(negative_prompt, (str, list)):
781
- raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
782
- if isinstance(negative_prompt, list) and len(negative_prompt) != batch_size:
783
- raise ValueError(
784
- f"`negative_prompt` list length ({len(negative_prompt)}) must match batch_size ({batch_size})."
785
- )
786
-
787
- if prompt_embeds is not None and prompt_attention_mask is None:
788
- raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
789
-
790
- if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
791
- raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
792
-
793
- if prompt_embeds is not None and negative_prompt_embeds is not None:
794
- if prompt_embeds.shape != negative_prompt_embeds.shape:
795
- raise ValueError(
796
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
797
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
798
- f" {negative_prompt_embeds.shape}."
799
- )
800
- if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
801
- raise ValueError(
802
- "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
803
- f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
804
- f" {negative_prompt_attention_mask.shape}."
805
- )
806
-
807
- def _prepare_negative_prompt(
808
- self,
809
- negative_prompt: Optional[Union[str, List[str]]],
810
- batch_size: int,
811
- ) -> List[str]:
812
- """
813
- Prepare negative_prompt to match batch_size.
814
-
815
- Args:
816
- negative_prompt: None, a single string, or a list of strings matching batch_size.
817
- batch_size: The number of prompts in the batch.
818
-
819
- Returns:
820
- A list of negative prompts with length equal to batch_size.
821
- """
822
- if negative_prompt is None:
823
- return [""] * batch_size
824
- if isinstance(negative_prompt, str):
825
- return [negative_prompt] * batch_size
826
- return negative_prompt
827
-
828
- @staticmethod
829
- def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
830
- batch_size, num_channels, num_frames, height, width = latents.shape
831
- post_patch_num_frames = num_frames // patch_size_t
832
- post_patch_height = height // patch_size
833
- post_patch_width = width // patch_size
834
- latents = latents.reshape(
835
- batch_size,
836
- -1,
837
- post_patch_num_frames,
838
- patch_size_t,
839
- post_patch_height,
840
- patch_size,
841
- post_patch_width,
842
- patch_size,
843
- )
844
- latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
845
- return latents
846
-
847
- @staticmethod
848
- def _unpack_latents(
849
- latents: torch.Tensor,
850
- num_frames: int,
851
- height: int,
852
- width: int,
853
- patch_size: int = 1,
854
- patch_size_t: int = 1,
855
- ) -> torch.Tensor:
856
- batch_size = latents.size(0)
857
- latents = latents.reshape(
858
- batch_size,
859
- num_frames,
860
- height,
861
- width,
862
- -1,
863
- patch_size_t,
864
- patch_size,
865
- patch_size,
866
- )
867
- latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
868
- return latents
869
-
870
- @staticmethod
871
- def _normalize_latents(
872
- latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
873
- ) -> torch.Tensor:
874
- # Normalize latents across the channel dimension [B, C, F, H, W]
875
- latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
876
- latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
877
- latents = (latents - latents_mean) / latents_std
878
- return latents
879
-
880
- @staticmethod
881
- def _denormalize_latents(
882
- latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
883
- ) -> torch.Tensor:
884
- # Denormalize latents across the channel dimension [B, C, F, H, W]
885
- latents_mean = torch.tensor(latents_mean).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
886
- latents_std = torch.tensor(latents_std).view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
887
- latents = latents * latents_std + latents_mean
888
- return latents
889
-
890
- def prepare_latents(
891
- self,
892
- batch_size: int = 1,
893
- num_channels_latents: int = 16,
894
- height: int = 352,
895
- width: int = 640,
896
- num_frames: int = 65,
897
- dtype: Optional[torch.dtype] = None,
898
- device: Optional[torch.device] = None,
899
- generator: Optional[torch.Generator] = None,
900
- latents: Optional[torch.Tensor] = None,
901
- ) -> torch.Tensor:
902
- if latents is not None:
903
- return latents.to(device=device, dtype=dtype)
904
-
905
- shape = (
906
- batch_size,
907
- num_channels_latents,
908
- (num_frames - 1) // self.vae_scale_factor_temporal + 1,
909
- height // self.vae_scale_factor_spatial,
910
- width // self.vae_scale_factor_spatial,
911
- )
912
-
913
- if isinstance(generator, list) and len(generator) != batch_size:
914
- raise ValueError(
915
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
916
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
917
- )
918
-
919
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
920
- return latents
921
-
922
- @property
923
- def num_timesteps(self):
924
- return self._num_timesteps
925
-
926
- @property
927
- def current_timestep(self):
928
- return self._current_timestep
929
-
930
- @property
931
- def attention_kwargs(self):
932
- return self._attention_kwargs
933
-
934
- @property
935
- def interrupt(self):
936
- return self._interrupt
937
-
938
- @torch.no_grad()
939
- @replace_example_docstring(EXAMPLE_DOC_STRING)
940
- def __call__(
941
- self,
942
- prompt: Union[str, List[str]] | None = None,
943
- image=None,
944
- negative_prompt: Optional[
945
- Union[str, List[str]]
946
- ] = "text overlay, graphic overlay, watermark, logo, subtitles, timestamp, broadcast graphics, UI elements, random letters, frozen pose, rigid, static expression, jerky motion, mechanical motion, discontinuous motion, flat framing, depthless, dull lighting, monotone, crushed shadows, blown-out highlights, shifting background, fading background, poor continuity, identity drift, deformation, flickering, ghosting, smearing, duplication, mutated proportions, inconsistent clothing, flat colors, desaturated, tonally compressed, poor background separation, exposure shift, uneven brightness, color balance shift",
947
- height: int = 736,
948
- width: int = 1280,
949
- num_frames: int = 121,
950
- frame_rate: int = 24,
951
- num_inference_steps: int = 50,
952
- timesteps: List[int] | None = None,
953
- use_linear_quadratic_schedule: bool = False,
954
- linear_quadratic_emulating_steps: int = 250,
955
- num_videos_per_prompt: Optional[int] = 1,
956
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
957
- latents: Optional[torch.Tensor] = None,
958
- prompt_embeds: Optional[torch.Tensor] = None,
959
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
960
- prompt_attention_mask: Optional[torch.Tensor] = None,
961
- negative_prompt_embeds: Optional[torch.Tensor] = None,
962
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
963
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
964
- output_type: Optional[str] = "pil",
965
- return_dict: bool = True,
966
- attention_kwargs: Optional[Dict[str, Any]] = None,
967
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
968
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
969
- max_sequence_length: int = 512,
970
- use_attention_mask: bool = True,
971
- vae_batch_size: int | None = None,
972
- ):
973
- r"""
974
- Function invoked when calling the pipeline for generation.
975
-
976
- Args:
977
- prompt (`str` or `List[str]`, *optional*):
978
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
979
- instead.
980
- negative_prompt (`str` or `List[str]`, *optional*):
981
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
982
- `negative_prompt_embeds` instead. Ignored when not using guidance.
983
- height (`int`, defaults to `352`):
984
- The height in pixels of the generated image.
985
- width (`int`, defaults to `640`):
986
- The width in pixels of the generated image.
987
- num_frames (`int`, defaults to `65`):
988
- The number of video frames to generate
989
- frame_rate (`int`, defaults to `25`):
990
- Frame rate for the output video.
991
- num_inference_steps (`int`, *optional*, defaults to 50):
992
- The number of denoising steps. More denoising steps usually lead to a higher quality video at the
993
- expense of slower inference.
994
- timesteps (`List[int]`, *optional*):
995
- Custom timesteps to use for the denoising process.
996
- use_linear_quadratic_schedule (`bool`, defaults to `True`):
997
- Whether to use a linear-quadratic sigma schedule instead of the default linear schedule.
998
- This schedule combines linear interpolation in the first half (slow denoising at high noise)
999
- with quadratic interpolation in the second half (faster denoising toward clean image).
1000
- Requires `num_inference_steps` to be even.
1001
- linear_quadratic_emulating_steps (`int`, defaults to `250`):
1002
- Controls the slope of linear interpolation in the first half of the linear-quadratic schedule.
1003
- Higher values result in a gentler slope. Only used when `use_linear_quadratic_schedule=True`.
1004
- num_videos_per_prompt (`int`, *optional*, defaults to 1):
1005
- The number of videos to generate per prompt.
1006
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1007
- PyTorch Generator object(s) for deterministic generation.
1008
- latents (`torch.Tensor`, *optional*):
1009
- Pre-generated noisy latents.
1010
- prompt_embeds (`torch.Tensor`, *optional*):
1011
- Pre-generated text embeddings.
1012
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1013
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1014
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
1015
- prompt_attention_mask (`torch.Tensor`, *optional*):
1016
- Pre-generated attention mask for text embeddings.
1017
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1018
- Pre-generated negative text embeddings.
1019
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1020
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1021
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1022
- input argument.
1023
- negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
1024
- Pre-generated attention mask for negative text embeddings.
1025
- output_type (`str`, *optional*, defaults to `"pil"`):
1026
- The output format ("pil" or "np").
1027
- return_dict (`bool`, *optional*, defaults to `True`):
1028
- Whether to return a `MotifVideoPipelineOutput`.
1029
- attention_kwargs (`dict`, *optional*):
1030
- Arguments passed to the attention processor.
1031
- callback_on_step_end (`Callable`, *optional*):
1032
- Callback function called at the end of each step.
1033
- callback_on_step_end_tensor_inputs (`List`, *optional*):
1034
- Tensors to include in the callback.
1035
- max_sequence_length (`int` defaults to `512`):
1036
- Maximum sequence length for the tokenizer.
1037
-
1038
- Examples:
1039
-
1040
- Returns:
1041
- [`~pipelines.motif_video.MotifVideoPipelineOutput`] or `tuple`:
1042
- If `return_dict` is `True`, returns [`~pipelines.motif_video.MotifVideoPipelineOutput`],
1043
- otherwise returns a tuple where the first element is a list of generated video frames.
1044
- """
1045
-
1046
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1047
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1048
-
1049
- # 1. Define call parameters (batch_size needed for check_inputs)
1050
- if prompt is not None and isinstance(prompt, str):
1051
- batch_size = 1
1052
- elif prompt is not None and isinstance(prompt, list):
1053
- batch_size = len(prompt)
1054
- else:
1055
- batch_size = prompt_embeds.shape[0]
1056
-
1057
- # 2. Check inputs. Raise error if not correct
1058
- self.check_inputs(
1059
- prompt=prompt,
1060
- negative_prompt=negative_prompt,
1061
- height=height,
1062
- width=width,
1063
- batch_size=batch_size,
1064
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1065
- prompt_embeds=prompt_embeds,
1066
- negative_prompt_embeds=negative_prompt_embeds,
1067
- prompt_attention_mask=prompt_attention_mask,
1068
- negative_prompt_attention_mask=negative_prompt_attention_mask,
1069
- )
1070
-
1071
- self._attention_kwargs = attention_kwargs
1072
- self._interrupt = False
1073
- self._current_timestep = None
1074
-
1075
- # Auto-upgrade AdaptiveProjectedGuidance to VideoAdaptiveProjectedGuidance
1076
- # for video generation. Video-aware APG normalizes per-frame [C,H,W] instead
1077
- # of collapsing the temporal axis, preserving motion quality.
1078
- if type(self.guider) is AdaptiveProjectedGuidance:
1079
- self.guider = VideoAdaptiveProjectedGuidance(
1080
- guidance_scale=self.guider.guidance_scale,
1081
- adaptive_projected_guidance_rescale=self.guider.adaptive_projected_guidance_rescale,
1082
- adaptive_projected_guidance_momentum=self.guider.adaptive_projected_guidance_momentum,
1083
- eta=self.guider.eta,
1084
- use_original_formulation=self.guider.use_original_formulation,
1085
- )
1086
-
1087
- device = self._execution_device
1088
-
1089
- # 3. Prepare text embeddings
1090
- prompt_embeds, pooled_prompt_embeds, prompt_attention_mask, pos_actual_len = self.encode_prompt(
1091
- prompt=prompt,
1092
- num_videos_per_prompt=num_videos_per_prompt,
1093
- prompt_embeds=prompt_embeds,
1094
- pooled_prompt_embeds=pooled_prompt_embeds,
1095
- prompt_attention_mask=prompt_attention_mask,
1096
- max_sequence_length=max_sequence_length,
1097
- device=device,
1098
- )
1099
-
1100
- if not self.guider._enabled and pos_actual_len is not None:
1101
- prompt_embeds = prompt_embeds[:, :pos_actual_len, :]
1102
- prompt_attention_mask = None
1103
-
1104
- if self.guider._enabled:
1105
- negative_prompt = self._prepare_negative_prompt(negative_prompt, batch_size)
1106
- (
1107
- negative_prompt_embeds,
1108
- negative_pooled_prompt_embeds,
1109
- negative_prompt_attention_mask,
1110
- neg_actual_len,
1111
- ) = self.encode_prompt(
1112
- prompt=negative_prompt,
1113
- num_videos_per_prompt=num_videos_per_prompt,
1114
- prompt_embeds=negative_prompt_embeds,
1115
- pooled_prompt_embeds=negative_pooled_prompt_embeds,
1116
- prompt_attention_mask=negative_prompt_attention_mask,
1117
- max_sequence_length=max_sequence_length,
1118
- device=device,
1119
- )
1120
-
1121
- # Trim each to its own actual length — guider runs pos/neg in separate loop iterations,
1122
- # so different seq lengths are fine. No padding embeddings attend without mask.
1123
- if pos_actual_len is not None and neg_actual_len is not None:
1124
- prompt_embeds = prompt_embeds[:, :pos_actual_len, :]
1125
- negative_prompt_embeds = negative_prompt_embeds[:, :neg_actual_len, :]
1126
- prompt_attention_mask = None
1127
- negative_prompt_attention_mask = None
1128
-
1129
- num_channels_latents = self.vae.config.z_dim
1130
- latents = self.prepare_latents(
1131
- batch_size * num_videos_per_prompt,
1132
- num_channels_latents,
1133
- height,
1134
- width,
1135
- num_frames,
1136
- self.transformer.dtype,
1137
- device,
1138
- generator,
1139
- latents,
1140
- )
1141
-
1142
- # 4.5 Preprocess image for I2V conditioning
1143
- if image is not None:
1144
- from PIL import Image as PILImage
1145
-
1146
- if isinstance(image, PILImage.Image):
1147
- image = image.convert("RGB").resize((width, height), PILImage.LANCZOS)
1148
- image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
1149
- image = image * 2.0 - 1.0 # [0,1] -> [-1,1]
1150
- image = image.unsqueeze(0) # [1, C, H, W]
1151
- # Handle [C, H, W] -> [1, C, H, W]
1152
- if image.dim() == 3:
1153
- image = image.unsqueeze(0)
1154
- # [B, C, H, W] -> [B, 1, C, H, W] for video format
1155
- if image.dim() == 4:
1156
- image = image.unsqueeze(1)
1157
- image = image.to(device=device, dtype=self.vae.dtype)
1158
-
1159
- # 5. Prepare timesteps (including mu calculation)
1160
-
1161
- # Recalculate latent dims based on VAE for mu calculation
1162
- latent_height = height // self.vae_scale_factor_spatial
1163
- latent_width = width // self.vae_scale_factor_spatial
1164
- latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
1165
-
1166
- # Calculate sequence length based on *packed* dimensions if transformer uses packing
1167
- # Packed dims: H/patch, W/patch, F/patch_t
1168
- packed_latent_height = latent_height // self.transformer_spatial_patch_size
1169
- packed_latent_width = latent_width // self.transformer_spatial_patch_size
1170
- packed_latent_num_frames = latent_num_frames // self.transformer_temporal_patch_size
1171
- video_sequence_length = packed_latent_num_frames * packed_latent_height * packed_latent_width
1172
-
1173
- # Compute sigmas: use linear-quadratic schedule if enabled, otherwise default linear
1174
- _is_flow_multistep = isinstance(
1175
- self.scheduler,
1176
- (
1177
- DPMSolverMultistepScheduler,
1178
- UniPCMultistepScheduler,
1179
- FlowUniPCMultistepScheduler,
1180
- ),
1181
- )
1182
-
1183
- # Compute mu once, shared by both branches (required by FlowUniPCMultistepScheduler)
1184
- mu = calculate_shift(
1185
- video_sequence_length,
1186
- self.scheduler.config.get("base_image_seq_len", 256),
1187
- self.scheduler.config.get("max_image_seq_len", 4096),
1188
- self.scheduler.config.get("base_shift", 0.5),
1189
- self.scheduler.config.get("max_shift", 1.15),
1190
- )
1191
-
1192
- if _is_flow_multistep:
1193
- # DPMSolver/UniPC manage their own sigma schedule via use_flow_sigmas + flow_shift.
1194
- # Pass mu for dynamic shifting support (required by FlowUniPCMultistepScheduler).
1195
- timesteps, num_inference_steps = retrieve_timesteps(
1196
- self.scheduler,
1197
- num_inference_steps,
1198
- device,
1199
- timesteps,
1200
- mu=mu,
1201
- )
1202
- else:
1203
- if use_linear_quadratic_schedule:
1204
- # Linear-quadratic schedule computes sigmas internally in retrieve_timesteps
1205
- sigmas = None
1206
- else:
1207
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1208
-
1209
- timesteps, num_inference_steps = retrieve_timesteps(
1210
- self.scheduler,
1211
- num_inference_steps,
1212
- device,
1213
- timesteps,
1214
- sigmas=sigmas,
1215
- use_linear_quadratic_schedule=use_linear_quadratic_schedule,
1216
- linear_quadratic_emulating_steps=linear_quadratic_emulating_steps,
1217
- mu=mu,
1218
- )
1219
-
1220
- # Get conditioning tensors
1221
- latent_condition, latent_mask, image_embeds = self._prepare_first_frame_conditioning(
1222
- image,
1223
- latents,
1224
- use_conditioning=image is not None,
1225
- generator=generator,
1226
- )
1227
-
1228
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1229
- self._num_timesteps = len(timesteps)
1230
-
1231
- # 6. Denoising loop
1232
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1233
- for i, t in enumerate(timesteps):
1234
- if self.interrupt:
1235
- continue
1236
-
1237
- self._current_timestep = t
1238
-
1239
- # Concatenate current latents with conditioning for this timestep
1240
- # [latents | latent_condition | latent_mask]
1241
- hidden_states = torch.cat([latents, latent_condition, latent_mask], dim=1)
1242
-
1243
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1244
- timestep = t.expand(latents.shape[0])
1245
-
1246
- # Step 1: Collect model inputs needed for the guidance method
1247
- # conditional inputs should always be first element in the tuple
1248
- guider_inputs = {
1249
- "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
1250
- }
1251
- if use_attention_mask and prompt_attention_mask is not None:
1252
- guider_inputs["encoder_attention_mask"] = (
1253
- prompt_attention_mask,
1254
- negative_prompt_attention_mask,
1255
- )
1256
- if self.transformer.config.pooled_projection_dim is not None:
1257
- guider_inputs["pooled_projections"] = (
1258
- pooled_prompt_embeds,
1259
- negative_pooled_prompt_embeds,
1260
- )
1261
- if image_embeds is not None:
1262
- guider_inputs["image_embeds"] = (image_embeds, image_embeds)
1263
-
1264
- # Step 2: Update guider's internal state for this denoising step
1265
- self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
1266
- # Sigma injection for guiders that support sigma-based gating
1267
- # (Kynkäänniemi 2024). Must precede `prepare_inputs` because
1268
- # `num_conditions` → `_is_cfg_enabled()` reads `_current_sigma`.
1269
- # Duck-typed so diffusers-native guiders are unaffected; guard
1270
- # on scheduler too since some schedulers don't expose `sigmas`.
1271
- if hasattr(self.guider, "_current_sigma") and hasattr(self.scheduler, "sigmas"):
1272
- self.guider._current_sigma = float(self.scheduler.sigmas[i])
1273
-
1274
- # Step 3: Prepare batched model inputs based on the guidance method
1275
- # The guider splits model inputs into separate batches for conditional/unconditional predictions.
1276
- # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
1277
- # you will get a guider_state with two batches:
1278
- # guider_state = [
1279
- # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
1280
- # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
1281
- # ]
1282
- # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
1283
- guider_state = self.guider.prepare_inputs(guider_inputs)
1284
-
1285
- # Step 4: Run the denoiser for each batch
1286
- # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
1287
- # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
1288
- for guider_state_batch in guider_state:
1289
- self.guider.prepare_models(self.transformer)
1290
-
1291
- # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
1292
- cond_kwargs = {
1293
- input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
1294
- }
1295
-
1296
- tread_disabled = getattr(self.guider, "_current_tread_disabled", False)
1297
-
1298
- # Override TREAD selection ratio per batch if the guider provides one
1299
- selection_ratio = getattr(self.guider, "_current_selection_ratio", None)
1300
- tread_mixin = getattr(self.transformer, "_inference_tread_mixin", None)
1301
- if (
1302
- selection_ratio is not None
1303
- and tread_mixin is not None
1304
- and tread_mixin._tread_route is not None
1305
- ):
1306
- tread_mixin._tread_route["sel"] = selection_ratio
1307
-
1308
- # e.g. "pred_cond"/"pred_uncond"
1309
- context_name = getattr(guider_state_batch, self.guider._identifier_key)
1310
- with self.transformer.cache_context(context_name):
1311
- # Run denoiser and store noise prediction in this batch
1312
-
1313
- noise_pred = self.transformer(
1314
- hidden_states=hidden_states,
1315
- timestep=timestep,
1316
- attention_kwargs=self.attention_kwargs,
1317
- return_dict=False,
1318
- tread_disabled=tread_disabled,
1319
- **cond_kwargs,
1320
- )[0].clone()
1321
-
1322
- guider_state_batch.noise_pred = noise_pred
1323
- # Cleanup model (e.g., remove hooks)
1324
- self.guider.cleanup_models(self.transformer)
1325
-
1326
- # Step 5: Combine predictions using the guidance method
1327
- # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
1328
- # Continuing the CFG example, the guider receives:
1329
- # guider_state = [
1330
- # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
1331
- # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
1332
- # ]
1333
- # And extracts predictions using the __guidance_identifier__:
1334
- # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
1335
- # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
1336
- # Then applies CFG formula:
1337
- # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
1338
- # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
1339
- noise_pred = self.guider(guider_state)[0]
1340
-
1341
- # compute the previous noisy sample x_t -> x_t-1
1342
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1343
-
1344
- if callback_on_step_end is not None:
1345
- callback_kwargs = {}
1346
- for k in callback_on_step_end_tensor_inputs:
1347
- callback_kwargs[k] = locals()[k]
1348
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1349
-
1350
- latents = callback_outputs.pop("latents", latents)
1351
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1352
- # Handle negative embeds if needed by callback
1353
- if "negative_prompt_embeds" in callback_outputs:
1354
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds")
1355
-
1356
- # call the callback, if provided
1357
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1358
- progress_bar.update()
1359
-
1360
- if XLA_AVAILABLE:
1361
- xm.mark_step()
1362
-
1363
- self._current_timestep = None
1364
-
1365
- if output_type == "latent":
1366
- video = latents
1367
- else:
1368
- latents = latents.to(self.vae.dtype)
1369
- latents = self._denormalize_latents(latents, self.vae.config.latents_mean, self.vae.config.latents_std)
1370
- if vae_batch_size is not None and latents.shape[0] > vae_batch_size:
1371
- video_chunks = []
1372
- for i in range(0, latents.shape[0], vae_batch_size):
1373
- chunk = latents[i : i + vae_batch_size]
1374
- video_chunks.append(self.vae.decode(chunk, return_dict=False)[0])
1375
- video = torch.cat(video_chunks, dim=0)
1376
- del video_chunks
1377
- else:
1378
- video = self.vae.decode(latents, return_dict=False)[0]
1379
- video = self.video_processor.postprocess_video(video, output_type=output_type)
1380
-
1381
- # Offload all models
1382
- self.maybe_free_model_hooks()
1383
-
1384
- if not return_dict:
1385
- return (video,)
1386
-
1387
- # Return updated output type
1388
- return MotifVideoPipelineOutput(frames=video)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
transformer/config.json CHANGED
@@ -3,7 +3,6 @@
3
  "_diffusers_version": "0.36.0",
4
  "_library": "diffusers",
5
  "attention_head_dim": 128,
6
- "base_latent_size": null,
7
  "image_embed_dim": 1152,
8
  "in_channels": 33,
9
  "mlp_ratio": 4.0,
@@ -15,7 +14,6 @@
15
  "out_channels": 16,
16
  "patch_size": 2,
17
  "patch_size_t": 1,
18
- "pooled_projection_dim": null,
19
  "qk_norm": "rms_norm",
20
  "rope_axes_dim": [
21
  16,
 
3
  "_diffusers_version": "0.36.0",
4
  "_library": "diffusers",
5
  "attention_head_dim": 128,
 
6
  "image_embed_dim": 1152,
7
  "in_channels": 33,
8
  "mlp_ratio": 4.0,
 
14
  "out_channels": 16,
15
  "patch_size": 2,
16
  "patch_size_t": 1,
 
17
  "qk_norm": "rms_norm",
18
  "rope_axes_dim": [
19
  16,
transformer/transformer_motif_video.py DELETED
@@ -1,1350 +0,0 @@
1
- # Copyright 2026 Motif Technologies. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import math
16
- from functools import lru_cache
17
- from typing import Any, Dict, List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import torch
21
- import torch.nn as nn
22
- import torch.nn.functional as F
23
- from diffusers.configuration_utils import ConfigMixin, register_to_config
24
- from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
25
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
26
- from diffusers.models.attention import FeedForward
27
- from diffusers.models.attention_processor import Attention, AttentionProcessor
28
- from diffusers.models.cache_utils import CacheMixin
29
- from diffusers.models.embeddings import (
30
- PixArtAlphaTextProjection,
31
- TimestepEmbedding,
32
- Timesteps,
33
- )
34
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
35
- from diffusers.models.modeling_utils import ModelMixin
36
- from diffusers.models.normalization import (
37
- AdaLayerNormContinuous,
38
- AdaLayerNormZero,
39
- AdaLayerNormZeroSingle,
40
- )
41
- from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
42
-
43
- # Stub functions for TREAD (Token REduction with Approximated Distillation).
44
- # These stubs ensure TREAD code paths are never activated during inference
45
- # without requiring the motif_core package.
46
- def is_tread_start(block_idx, start, end): return False
47
- def is_tread_end(block_idx, start, end): return False
48
-
49
-
50
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
-
52
- NUM_TRAIN_TIMESTEPS = 1000
53
-
54
-
55
- def apply_rotary_emb(
56
- x: torch.Tensor,
57
- freqs_cis: Tuple[torch.Tensor, torch.Tensor],
58
- use_real: bool = True,
59
- use_real_unbind_dim: int = -1,
60
- ) -> torch.Tensor:
61
- """
62
- Apply rotary positional embeddings (RoPE) to input tensors.
63
-
64
- This implementation supports both standard 2D RoPE tensors [L, Dh] and batched 4D RoPE
65
- tensors [B, 1, L, Dh] for compatibility with TREAD's token-dropping mechanism where
66
- different batches may have different token subsets.
67
-
68
- Args:
69
- x: Input tensor of shape [B, H, L, Dh].
70
- freqs_cis: Tuple of (cos, sin) tensors. Supports shapes [L, Dh] or [B, 1, L, Dh].
71
- use_real: Whether to use real-valued RoPE implementation.
72
- use_real_unbind_dim: Dimension to unbind when using real-valued RoPE (-1 or -2).
73
-
74
- Returns:
75
- Tensor with rotary embeddings applied, same shape as input x.
76
- """
77
- if use_real:
78
- cos, sin = freqs_cis
79
- if cos.dim() == 2: # [L, Dh] → [1, 1, L, Dh]
80
- cos = cos.unsqueeze(0).unsqueeze(0)
81
- sin = sin.unsqueeze(0).unsqueeze(0)
82
- if cos.dim() != 4 or sin.dim() != 4:
83
- raise RuntimeError(f"RoPE must be 2D or 4D, got cos={cos.dim()}D, sin={sin.dim()}D")
84
-
85
- cos, sin = cos.to(x.device), sin.to(x.device)
86
-
87
- if cos.size(-2) != x.size(-2) or cos.size(-1) != x.size(-1):
88
- raise RuntimeError(
89
- f"RoPE shape mismatch: rope[-2:]=({cos.size(-2)},{cos.size(-1)}) vs x[-2:]=({x.size(-2)},{x.size(-1)})"
90
- )
91
-
92
- if use_real_unbind_dim == -1:
93
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
94
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
95
- elif use_real_unbind_dim == -2:
96
- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
97
- x_rotated = torch.cat([-x_imag, x_real], dim=-1)
98
- else:
99
- raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
100
-
101
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
102
- return out
103
- else:
104
- x_rot = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
105
- freqs = freqs_cis.unsqueeze(2)
106
- x_out = torch.view_as_real(x_rot * freqs).flatten(3)
107
- return x_out.type_as(x)
108
-
109
-
110
- class MotifVideoAttnProcessor2_0:
111
- def __init__(self):
112
- if not hasattr(F, "scaled_dot_product_attention"):
113
- raise ImportError(
114
- "MotifVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
115
- )
116
-
117
- def __call__(
118
- self,
119
- attn: Attention,
120
- hidden_states: torch.Tensor,
121
- encoder_hidden_states: Optional[torch.Tensor] = None,
122
- attention_mask: Optional[torch.Tensor] = None,
123
- image_rotary_emb: Optional[torch.Tensor] = None,
124
- query_input: Optional[torch.Tensor] = None,
125
- key_input: Optional[torch.Tensor] = None,
126
- value_input: Optional[torch.Tensor] = None,
127
- ) -> torch.Tensor:
128
- # Cross-attention mode: query already projected externally (cross_attn_query_proj + norm),
129
- # skip to_q and only apply reshape + norm_q + RoPE. K/V use to_k/to_v as normal.
130
- if query_input is not None:
131
- query = query_input.unflatten(2, (attn.heads, -1)).transpose(1, 2)
132
- key = attn.to_k(key_input)
133
- value = attn.to_v(value_input)
134
-
135
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
136
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
137
-
138
- if attn.norm_q is not None:
139
- query = attn.norm_q(query)
140
- if attn.norm_k is not None:
141
- key = attn.norm_k(key)
142
-
143
- if image_rotary_emb is not None:
144
- query = apply_rotary_emb(query, image_rotary_emb)
145
-
146
- hidden_states = F.scaled_dot_product_attention(
147
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
148
- )
149
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
150
- hidden_states = hidden_states.to(query.dtype)
151
- return hidden_states, None
152
-
153
- if attn.add_q_proj is None and encoder_hidden_states is not None:
154
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
155
-
156
- # 1. QKV projections
157
- query = attn.to_q(hidden_states)
158
- key = attn.to_k(hidden_states)
159
- value = attn.to_v(hidden_states)
160
-
161
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
162
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
163
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
164
-
165
- # 2. QK normalization
166
- if attn.norm_q is not None:
167
- query = attn.norm_q(query)
168
- if attn.norm_k is not None:
169
- key = attn.norm_k(key)
170
-
171
- # 3. Rotational positional embeddings applied to latent stream
172
- if image_rotary_emb is not None:
173
- if attn.add_q_proj is None and encoder_hidden_states is not None:
174
- query = torch.cat(
175
- [
176
- apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
177
- query[:, :, -encoder_hidden_states.shape[1] :],
178
- ],
179
- dim=2,
180
- )
181
- key = torch.cat(
182
- [
183
- apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
184
- key[:, :, -encoder_hidden_states.shape[1] :],
185
- ],
186
- dim=2,
187
- )
188
- else:
189
- query = apply_rotary_emb(query, image_rotary_emb)
190
- key = apply_rotary_emb(key, image_rotary_emb)
191
-
192
- # 4. Encoder condition QKV projection and normalization
193
- if attn.add_q_proj is not None and encoder_hidden_states is not None:
194
- encoder_query = attn.add_q_proj(encoder_hidden_states)
195
- encoder_key = attn.add_k_proj(encoder_hidden_states)
196
- encoder_value = attn.add_v_proj(encoder_hidden_states)
197
-
198
- encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
199
- encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
200
- encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
201
-
202
- if attn.norm_added_q is not None:
203
- encoder_query = attn.norm_added_q(encoder_query)
204
- if attn.norm_added_k is not None:
205
- encoder_key = attn.norm_added_k(encoder_key)
206
-
207
- query = torch.cat([query, encoder_query], dim=2)
208
- key = torch.cat([key, encoder_key], dim=2)
209
- value = torch.cat([value, encoder_value], dim=2)
210
-
211
- # 5. Attention
212
- hidden_states = F.scaled_dot_product_attention(
213
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
214
- )
215
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
216
- hidden_states = hidden_states.to(query.dtype)
217
-
218
- # 6. Output projection
219
- if encoder_hidden_states is not None:
220
- hidden_states, encoder_hidden_states = (
221
- hidden_states[:, : -encoder_hidden_states.shape[1]],
222
- hidden_states[:, -encoder_hidden_states.shape[1] :],
223
- )
224
-
225
- if getattr(attn, "to_out", None) is not None:
226
- hidden_states = attn.to_out[0](hidden_states)
227
- hidden_states = attn.to_out[1](hidden_states)
228
-
229
- if getattr(attn, "to_add_out", None) is not None:
230
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
231
-
232
- return hidden_states, encoder_hidden_states
233
-
234
-
235
- class MotifVideoPatchEmbed(nn.Module):
236
- def __init__(
237
- self,
238
- patch_size: Union[int, Tuple[int, int, int]] = 16,
239
- in_chans: int = 3,
240
- embed_dim: int = 768,
241
- ) -> None:
242
- super().__init__()
243
-
244
- patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
245
- self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
246
-
247
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
248
- hidden_states = self.proj(hidden_states)
249
- hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
250
- return hidden_states
251
-
252
-
253
- class MotifVideoAdaNorm(nn.Module):
254
- def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
255
- super().__init__()
256
-
257
- out_features = out_features or 2 * in_features
258
- self.linear = nn.Linear(in_features, out_features)
259
- self.nonlinearity = nn.SiLU()
260
-
261
- def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
262
- temb = self.linear(self.nonlinearity(temb))
263
- gate_msa, gate_mlp = temb.chunk(2, dim=1)
264
- gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
265
- return gate_msa, gate_mlp
266
-
267
-
268
- class MotifVideoConditionEmbedding(nn.Module):
269
- def __init__(
270
- self,
271
- embedding_dim: int,
272
- pooled_projection_dim: int | None,
273
- ):
274
- super().__init__()
275
-
276
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
277
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
278
-
279
- if isinstance(pooled_projection_dim, int):
280
- self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
281
-
282
- def forward(
283
- self,
284
- timestep: torch.Tensor,
285
- pooled_projection: torch.Tensor | None = None,
286
- ) -> Tuple[torch.Tensor, torch.Tensor]:
287
- timesteps_proj = self.time_proj(timestep)
288
- timestep_embedder_dtype = next(self.timestep_embedder.parameters()).dtype
289
- conditioning = self.timestep_embedder(timesteps_proj.to(timestep_embedder_dtype)) # (N, D)
290
- if pooled_projection is not None:
291
- conditioning = conditioning + self.text_embedder(pooled_projection)
292
-
293
- token_replace_emb = None
294
-
295
- return conditioning, token_replace_emb
296
-
297
-
298
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L485-L486
299
- def find_correction_factor(num_rotations, dim, base, max_position_embeddings):
300
- dtype = num_rotations.dtype if isinstance(num_rotations, torch.Tensor) else torch.float32
301
- max_pos_tensor = torch.as_tensor(max_position_embeddings, dtype=dtype)
302
- return (dim * torch.log(max_pos_tensor / (num_rotations * 2 * math.pi))) / (
303
- 2 * math.log(base)
304
- ) # Inverse dim formula to find number of rotations
305
-
306
-
307
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L489-L495
308
- def find_correction_range(low_ratio, high_ratio, dim, base, ori_max_pe_len):
309
- """
310
- Find the correction range for NTK-by-parts interpolation.
311
- """
312
- low = torch.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len))
313
- high = torch.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len))
314
- low = torch.clamp(low, min=0)
315
- high = torch.clamp(high, max=dim - 1)
316
- return low, high # Clamp values just in case
317
-
318
-
319
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L498-L504
320
- def linear_ramp_mask(min_val, max_val, num_dim):
321
- if isinstance(min_val, torch.Tensor):
322
- if (min_val == max_val).all():
323
- max_val = max_val + 0.001
324
- elif min_val == max_val:
325
- max_val += 0.001
326
-
327
- linear_func = (torch.arange(num_dim, dtype=torch.float32) - min_val) / (max_val - min_val)
328
- ramp_func = torch.clamp(linear_func, 0, 1)
329
- return ramp_func
330
-
331
-
332
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L507-L511
333
- def find_newbase_ntk(dim, base, scale):
334
- """
335
- Calculate the new base for NTK-aware scaling.
336
- """
337
- # Avoid division by zero when dim == 2 (or invalid smaller values).
338
- # In these degenerate cases, fall back to the original base (no NTK adjustment).
339
- if dim <= 2:
340
- return base
341
- return base * (scale ** (dim / (dim - 2)))
342
-
343
-
344
- # Copied from https://github.com/guyyariv/DyPE/blob/5dd4fab99b479ee487754140d717bfb888a6afa2/flux/transformer_flux.py#L514-L652
345
- def get_1d_rotary_pos_embed(
346
- dim: int,
347
- pos: Union[np.ndarray, int],
348
- theta: float = 10000.0,
349
- use_real=False,
350
- linear_factor=1.0,
351
- ntk_factor=1.0,
352
- repeat_interleave_real=True,
353
- freqs_dtype=torch.float32,
354
- yarn=False,
355
- max_pe_len=None,
356
- ori_max_pe_len=64,
357
- dype=False,
358
- current_timestep=1.0,
359
- ):
360
- """
361
- Precompute the frequency tensor for complex exponentials with RoPE.
362
- Supports YARN interpolation for vision transformers.
363
-
364
- Args:
365
- dim (`int`):
366
- Dimension of the frequency tensor.
367
- pos (`np.ndarray` or `int`):
368
- Position indices for the frequency tensor. [S] or scalar.
369
- theta (`float`, *optional*, defaults to 10000.0):
370
- Scaling factor for frequency computation.
371
- use_real (`bool`, *optional*, defaults to False):
372
- If True, return real part and imaginary part separately. Otherwise, return complex numbers.
373
- linear_factor (`float`, *optional*, defaults to 1.0):
374
- Scaling factor for linear interpolation.
375
- ntk_factor (`float`, *optional*, defaults to 1.0):
376
- Scaling factor for NTK-Aware RoPE.
377
- repeat_interleave_real (`bool`, *optional*, defaults to True):
378
- If True and use_real, real and imaginary parts are interleaved with themselves to reach dim.
379
- Otherwise, they are concatenated.
380
- freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
381
- Data type of the frequency tensor.
382
- yarn (`bool`, *optional*, defaults to False):
383
- If True, use YARN interpolation combining NTK, linear, and base methods.
384
- max_pe_len (`int`, *optional*):
385
- Maximum position encoding length (current patches for vision models).
386
- ori_max_pe_len (`int`, *optional*, defaults to 64):
387
- Original maximum position encoding length (base patches for vision models).
388
- dype (`bool`, *optional*, defaults to False):
389
- If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
390
- current_timestep (`float`, *optional*, defaults to 1.0):
391
- Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.
392
-
393
- Returns:
394
- `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
395
- If use_real=True, returns tuple of (cos, sin) tensors.
396
- """
397
- assert dim % 2 == 0
398
-
399
- if isinstance(pos, int):
400
- pos = torch.arange(pos)
401
- if isinstance(pos, np.ndarray):
402
- pos = torch.from_numpy(pos)
403
-
404
- device = pos.device
405
-
406
- if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
407
- if not isinstance(max_pe_len, torch.Tensor):
408
- max_pe_len = torch.tensor(max_pe_len, dtype=freqs_dtype, device=device)
409
-
410
- scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
411
-
412
- beta_0 = 1.25
413
- beta_1 = 0.75
414
- gamma_0 = 16
415
- gamma_1 = 2
416
-
417
- freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
418
-
419
- freqs_linear = 1.0 / torch.einsum(
420
- "..., f -> ... f",
421
- scale,
422
- (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)),
423
- )
424
-
425
- new_base = find_newbase_ntk(dim, theta, scale)
426
- if new_base.dim() > 0:
427
- new_base = new_base.view(-1, 1)
428
- freqs_ntk = 1.0 / torch.pow(new_base, (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim))
429
- if freqs_ntk.dim() > 1:
430
- freqs_ntk = freqs_ntk.squeeze()
431
-
432
- if dype:
433
- beta_0 = torch.pow(beta_0, 2.0 * torch.pow(current_timestep, 2.0))
434
- beta_1 = torch.pow(beta_1, 2.0 * torch.pow(current_timestep, 2.0))
435
-
436
- low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
437
- high = torch.clamp(high, max=dim // 2)
438
-
439
- freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
440
- freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
441
-
442
- if dype:
443
- gamma_0 = torch.pow(gamma_0, 2.0 * torch.pow(current_timestep, 2.0))
444
- gamma_1 = torch.pow(gamma_1, 2.0 * torch.pow(current_timestep, 2.0))
445
-
446
- low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
447
- high = torch.clamp(high, max=dim // 2)
448
-
449
- freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(freqs_dtype)
450
- freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
451
-
452
- else:
453
- theta_ntk = theta * ntk_factor
454
- freqs = 1.0 / (theta_ntk ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)) / linear_factor
455
-
456
- freqs = torch.outer(pos, freqs)
457
-
458
- is_npu = freqs.device.type == "npu"
459
- if is_npu:
460
- freqs = freqs.float()
461
-
462
- if use_real and repeat_interleave_real:
463
- freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
464
- freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()
465
-
466
- if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
467
- mscale = torch.where(scale <= 1.0, 1.0, 0.1 * torch.log(scale) + 1.0).to(scale)
468
- freqs_cos = freqs_cos * mscale
469
- freqs_sin = freqs_sin * mscale
470
-
471
- return freqs_cos, freqs_sin
472
- elif use_real:
473
- freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
474
- freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
475
- return freqs_cos, freqs_sin
476
- else:
477
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
478
- return freqs_cis
479
-
480
-
481
- class MotifVideoRotaryPosEmbed(nn.Module):
482
- def __init__(
483
- self,
484
- patch_size: int,
485
- patch_size_t: int,
486
- rope_dim: List[int],
487
- theta: float = 256.0,
488
- base_latent_size: int | None = None,
489
- ):
490
- """
491
- Rotary Positional Embedding (RoPE) for video latents.
492
-
493
- Args:
494
- patch_size (`int`):
495
- Spatial patch size (e.g., 2).
496
- patch_size_t (`int`):
497
- Temporal patch size (e.g., 1).
498
- rope_dim (`List[int]`):
499
- Dimensions for RoPE across [Time, Height, Width] axes.
500
- theta (`float`, *optional*, defaults to 256.0):
501
- Base frequency for rotary embeddings.
502
- base_latent_size (`int`, *optional*):
503
- The maximum spatial dimension (in latent units) seen during training,
504
- i.e. `training_resolution / vae_scale_factor_spatial`.
505
- For example, for 1280x1280 training images and a VAE spatial downscale
506
- (`vae_scale_factor_spatial`) of 8, this would be 160; for a downscale
507
- of 16, it would be 80.
508
- """
509
- super().__init__()
510
-
511
- self.patch_size = patch_size
512
- self.patch_size_t = patch_size_t
513
- self.rope_dim = rope_dim
514
- self.theta = theta
515
- self.base_latent_size = base_latent_size
516
-
517
- @lru_cache(maxsize=8)
518
- def _get_base_patch_grid_size(self, base_latent_size: Optional[int], patch_size: int) -> Optional[int]:
519
- return base_latent_size // patch_size if base_latent_size else None
520
-
521
- @lru_cache(maxsize=8)
522
- def _get_dynamic_interpolation_scale(self, h: int, w: int, base_grid_size: int) -> float:
523
- return math.sqrt(h * w / (base_grid_size**2))
524
-
525
- def forward(self, hidden_states: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor:
526
- if self.training:
527
- assert self.base_latent_size is None, (
528
- "RoPE interpolation/extrapolation logic should only be enabled for inference. "
529
- f"During training, base_latent_size must be None, but got {self.base_latent_size!r}."
530
- )
531
-
532
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
533
- rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
534
-
535
- axes_grids = []
536
- for i in range(3):
537
- # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
538
- # original implementation creates it on CPU and then moves it to device. This results in numerical
539
- # differences in layerwise debugging outputs, but visually it is the same.
540
- grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
541
- axes_grids.append(grid)
542
- grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
543
- grid = torch.stack(grid, dim=0) # [3, W, H, T]
544
-
545
- base_patch_grid_size = self._get_base_patch_grid_size(self.base_latent_size, self.patch_size)
546
- if base_patch_grid_size is not None:
547
- if base_patch_grid_size <= 0:
548
- raise ValueError(f"base_patch_grid_size must be a positive number, got {base_patch_grid_size}.")
549
- dynamic_interpolation_scale = self._get_dynamic_interpolation_scale(
550
- rope_sizes[1], rope_sizes[2], base_patch_grid_size
551
- )
552
-
553
- normalized_timestep = torch.tensor(1.0)
554
- if not self.training and timestep is not None:
555
- normalized_timestep = timestep[0] / NUM_TRAIN_TIMESTEPS
556
-
557
- freqs = []
558
- for i in range(3):
559
- common_kwargs = {
560
- "dim": self.rope_dim[i],
561
- "pos": grid[i].reshape(-1),
562
- "theta": self.theta,
563
- "use_real": True,
564
- "freqs_dtype": torch.float64,
565
- }
566
-
567
- # Apply scaling only to spatial dimensions (Height and Width, i=1 and i=2)
568
- if i > 0 and base_patch_grid_size is not None and dynamic_interpolation_scale > 1.0:
569
- # We project the training base to the current size using the uniform scale factor.
570
- # max_pe_len tells the RoPE logic the "new" maximum length it's dealing with.
571
- max_pe_len = torch.tensor(
572
- base_patch_grid_size * dynamic_interpolation_scale,
573
- dtype=torch.float64,
574
- device=hidden_states.device,
575
- )
576
-
577
- freq = get_1d_rotary_pos_embed(
578
- **common_kwargs,
579
- yarn=True, # Enable Yet Another RoPE extensioN (YARN) for extrapolation
580
- max_pe_len=max_pe_len,
581
- ori_max_pe_len=base_patch_grid_size, # The original training scale
582
- dype=True, # Enable Dynamic Position Encoding (time-aware)
583
- current_timestep=normalized_timestep,
584
- )
585
- else:
586
- # Time dimension OR within training bounds -> Standard RoPE
587
- freq = get_1d_rotary_pos_embed(**common_kwargs)
588
-
589
- freqs.append(freq)
590
-
591
- freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
592
- freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
593
- return freqs_cos, freqs_sin
594
-
595
-
596
- class MotifVideoImageProjection(nn.Module):
597
- def __init__(self, in_features: int, hidden_size: int):
598
- super().__init__()
599
- self.norm_in = nn.LayerNorm(in_features)
600
- self.linear_1 = nn.Linear(in_features, in_features)
601
- self.act_fn = nn.GELU()
602
- self.linear_2 = nn.Linear(in_features, hidden_size)
603
- self.norm_out = nn.LayerNorm(hidden_size)
604
-
605
- def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
606
- hidden_states = self.norm_in(image_embeds)
607
- hidden_states = self.linear_1(hidden_states)
608
- hidden_states = self.act_fn(hidden_states)
609
- hidden_states = self.linear_2(hidden_states)
610
- hidden_states = self.norm_out(hidden_states)
611
- return hidden_states
612
-
613
-
614
- class MotifVideoSingleTransformerBlock(nn.Module):
615
- def __init__(
616
- self,
617
- num_attention_heads: int,
618
- attention_head_dim: int,
619
- mlp_ratio: float = 4.0,
620
- qk_norm: str = "rms_norm",
621
- norm_type: str = "layer_norm",
622
- enable_text_cross_attention: bool = False,
623
- ) -> None:
624
- super().__init__()
625
-
626
- hidden_size = num_attention_heads * attention_head_dim
627
- mlp_dim = int(hidden_size * mlp_ratio)
628
-
629
- self.attn = Attention(
630
- query_dim=hidden_size,
631
- cross_attention_dim=None,
632
- dim_head=attention_head_dim,
633
- heads=num_attention_heads,
634
- out_dim=hidden_size,
635
- bias=True,
636
- processor=MotifVideoAttnProcessor2_0(),
637
- qk_norm=qk_norm,
638
- eps=1e-6,
639
- pre_only=True,
640
- )
641
-
642
- self.enable_text_cross_attention = enable_text_cross_attention
643
- if enable_text_cross_attention:
644
- self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
645
- self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
646
- self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
647
- nn.init.zeros_(self.cross_attn_out_proj.weight)
648
- nn.init.zeros_(self.cross_attn_out_proj.bias)
649
-
650
- self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type=norm_type)
651
- self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
652
- self.act_mlp = nn.GELU(approximate="tanh")
653
- self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
654
-
655
- def forward(
656
- self,
657
- hidden_states: torch.Tensor,
658
- encoder_hidden_states: torch.Tensor,
659
- temb: torch.Tensor,
660
- attention_mask: Optional[torch.Tensor] = None,
661
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
662
- token_replace_emb: torch.Tensor | None = None,
663
- first_frame_num_tokens: int | None = None,
664
- image_embed_seq_len: int = 0,
665
- encoder_attention_mask: torch.Tensor | None = None,
666
- ) -> torch.Tensor:
667
- text_seq_length = encoder_hidden_states.shape[1]
668
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
669
-
670
- residual = hidden_states
671
-
672
- # 1. Input normalization
673
- norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
674
- mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
675
-
676
- norm_hidden_states, norm_encoder_hidden_states = (
677
- norm_hidden_states[:, :-text_seq_length, :],
678
- norm_hidden_states[:, -text_seq_length:, :],
679
- )
680
-
681
- # 2. Attention
682
- attn_output, context_attn_output = self.attn(
683
- hidden_states=norm_hidden_states,
684
- encoder_hidden_states=norm_encoder_hidden_states,
685
- attention_mask=attention_mask,
686
- image_rotary_emb=image_rotary_emb,
687
- )
688
-
689
- # Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
690
- if self.enable_text_cross_attention:
691
- txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
692
- text_mask = None
693
- if encoder_attention_mask is not None:
694
- text_mask = encoder_attention_mask[:, image_embed_seq_len:]
695
- text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
696
- cross_q = self.cross_attn_query_proj(attn_output)
697
- cross_output, _ = self.attn(
698
- hidden_states=cross_q,
699
- query_input=cross_q,
700
- key_input=txt_kv,
701
- value_input=txt_kv,
702
- attention_mask=text_mask,
703
- image_rotary_emb=image_rotary_emb,
704
- )
705
- attn_output = attn_output + self.cross_attn_out_proj(cross_output)
706
-
707
- attn_output = torch.cat([attn_output, context_attn_output], dim=1)
708
-
709
- # 3. Modulation and residual connection
710
- hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
711
- hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
712
- hidden_states = hidden_states + residual
713
-
714
- hidden_states, encoder_hidden_states = (
715
- hidden_states[:, :-text_seq_length, :],
716
- hidden_states[:, -text_seq_length:, :],
717
- )
718
- return hidden_states, encoder_hidden_states
719
-
720
-
721
- class MotifVideoTransformerBlock(nn.Module):
722
- def __init__(
723
- self,
724
- num_attention_heads: int,
725
- attention_head_dim: int,
726
- mlp_ratio: float,
727
- qk_norm: str = "rms_norm",
728
- norm_type: str = "layer_norm",
729
- enable_text_cross_attention: bool = False,
730
- ) -> None:
731
- super().__init__()
732
-
733
- hidden_size = num_attention_heads * attention_head_dim
734
-
735
- self.norm1 = AdaLayerNormZero(hidden_size, norm_type=norm_type)
736
- self.norm1_context = AdaLayerNormZero(hidden_size, norm_type=norm_type)
737
-
738
- self.attn = Attention(
739
- query_dim=hidden_size,
740
- cross_attention_dim=None,
741
- added_kv_proj_dim=hidden_size,
742
- dim_head=attention_head_dim,
743
- heads=num_attention_heads,
744
- out_dim=hidden_size,
745
- context_pre_only=False,
746
- bias=True,
747
- processor=MotifVideoAttnProcessor2_0(),
748
- qk_norm=qk_norm,
749
- eps=1e-6,
750
- )
751
-
752
- self.enable_text_cross_attention = enable_text_cross_attention
753
- if enable_text_cross_attention:
754
- self.cross_attn_query_proj = nn.Linear(hidden_size, hidden_size)
755
- self.cross_attn_query_norm = nn.LayerNorm(hidden_size, eps=1e-6)
756
- self.cross_attn_out_proj = nn.Linear(hidden_size, hidden_size)
757
- nn.init.zeros_(self.cross_attn_out_proj.weight)
758
- nn.init.zeros_(self.cross_attn_out_proj.bias)
759
-
760
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
761
- self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
762
-
763
- self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
764
- self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
765
-
766
- def forward(
767
- self,
768
- hidden_states: torch.Tensor,
769
- encoder_hidden_states: torch.Tensor,
770
- temb: torch.Tensor,
771
- attention_mask: Optional[torch.Tensor] = None,
772
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
773
- token_replace_emb: torch.Tensor | None = None,
774
- first_frame_num_tokens: int | None = None,
775
- image_embed_seq_len: int = 0,
776
- encoder_attention_mask: torch.Tensor | None = None,
777
- ) -> Tuple[torch.Tensor, torch.Tensor]:
778
- # 1. Input normalization
779
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
780
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
781
- encoder_hidden_states, emb=temb
782
- )
783
-
784
- # 2. Joint attention
785
- attn_output, context_attn_output = self.attn(
786
- hidden_states=norm_hidden_states,
787
- encoder_hidden_states=norm_encoder_hidden_states,
788
- attention_mask=attention_mask,
789
- image_rotary_emb=image_rotary_emb,
790
- )
791
-
792
- # 3. Modulation and residual connection
793
- hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
794
-
795
- # Text cross-attention: Q=proj(attn_output), K/V=normed text, reuse self.attn weights
796
- if self.enable_text_cross_attention:
797
- txt_kv = norm_encoder_hidden_states[:, image_embed_seq_len:, :]
798
- text_mask = None
799
- if encoder_attention_mask is not None:
800
- text_mask = encoder_attention_mask[:, image_embed_seq_len:]
801
- text_mask = text_mask.unsqueeze(1).unsqueeze(1).to(torch.bool) # [B, 1, 1, L_txt]
802
- cross_q = self.cross_attn_query_proj(attn_output)
803
- cross_output, _ = self.attn(
804
- hidden_states=cross_q,
805
- query_input=cross_q,
806
- key_input=txt_kv,
807
- value_input=txt_kv,
808
- attention_mask=text_mask,
809
- image_rotary_emb=image_rotary_emb,
810
- )
811
- hidden_states = hidden_states + self.cross_attn_out_proj(cross_output)
812
-
813
- encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
814
-
815
- norm_hidden_states = self.norm2(hidden_states)
816
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
817
-
818
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
819
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
820
-
821
- # 4. Feed-forward
822
- ff_output = self.ff(norm_hidden_states)
823
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
824
-
825
- hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
826
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
827
-
828
- return hidden_states, encoder_hidden_states
829
-
830
-
831
- TransformerBlockRegistry.register(
832
- model_class=MotifVideoTransformerBlock,
833
- metadata=TransformerBlockMetadata(
834
- return_hidden_states_index=0,
835
- return_encoder_hidden_states_index=1,
836
- ),
837
- )
838
- TransformerBlockRegistry.register(
839
- model_class=MotifVideoSingleTransformerBlock,
840
- metadata=TransformerBlockMetadata(
841
- return_hidden_states_index=0,
842
- return_encoder_hidden_states_index=1,
843
- ),
844
- )
845
-
846
-
847
- class MotifVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
848
- r"""
849
- A Transformer model for video-like data used in [MotifVideo](https://huggingface.co/motif/motifvideo).
850
-
851
- Args:
852
- in_channels (`int`, defaults to `16`):
853
- The number of channels in the input.
854
- out_channels (`int`, defaults to `16`):
855
- The number of channels in the output.
856
- num_attention_heads (`int`, defaults to `24`):
857
- The number of heads to use for multi-head attention.
858
- attention_head_dim (`int`, defaults to `128`):
859
- The number of channels in each head.
860
- num_layers (`int`, defaults to `20`):
861
- The number of layers of dual-stream blocks to use.
862
- num_single_layers (`int`, defaults to `40`):
863
- The number of layers of single-stream blocks to use.
864
-
865
- mlp_ratio (`float`, defaults to `4.0`):
866
- The ratio of the hidden layer size to the input size in the feedforward network.
867
- patch_size (`int`, defaults to `2`):
868
- The size of the spatial patches to use in the patch embedding layer.
869
- patch_size_t (`int`, defaults to `1`):
870
- The size of the temporal patches to use in the patch embedding layer.
871
- qk_norm (`str`, defaults to `rms_norm`):
872
- The normalization to use for the query and key projections in the attention layers.
873
- text_embed_dim (`int`, defaults to `4096`):
874
- Input dimension of text embeddings from the text encoder.
875
- rope_theta (`float`, defaults to `256.0`):
876
- The value of theta to use in the RoPE layer.
877
- rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
878
- The dimensions of the axes to use in the RoPE layer.
879
- base_latent_size (`int`, *optional*):
880
- The maximum spatial dimension (in latent units) seen during training.
881
- For example, if trained on 1280x1280 with a VAE downscale of 16, this is 80.
882
- """
883
-
884
- _supports_gradient_checkpointing = True
885
- _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
886
- _no_split_modules = [
887
- "MotifVideoTransformerBlock",
888
- "MotifVideoSingleTransformerBlock",
889
- "MotifVideoPatchEmbed",
890
- ]
891
-
892
- @register_to_config
893
- def __init__(
894
- self,
895
- in_channels: int = 33,
896
- out_channels: int = 16,
897
- num_attention_heads: int = 24,
898
- attention_head_dim: int = 128,
899
- num_layers: int = 20,
900
- num_single_layers: int = 40,
901
- num_decoder_layers: int = 0,
902
- mlp_ratio: float = 4.0,
903
- patch_size: int = 2,
904
- patch_size_t: int = 1,
905
- qk_norm: str = "rms_norm",
906
- norm_type: str = "layer_norm",
907
- text_embed_dim: int = 4096,
908
- image_embed_dim: int | None = None,
909
- pooled_projection_dim: int | None = None,
910
- rope_theta: float = 256.0,
911
- rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
912
- base_latent_size: int | None = None,
913
- enable_text_cross_attention_dual: bool = False,
914
- enable_text_cross_attention_single: bool = False,
915
- ) -> None:
916
- super().__init__()
917
-
918
- inner_dim = num_attention_heads * attention_head_dim
919
- out_channels = out_channels or in_channels
920
-
921
- # 1. Latent and condition embedders
922
- self.x_embedder = MotifVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
923
- self.context_embedder = PixArtAlphaTextProjection(in_features=text_embed_dim, hidden_size=inner_dim)
924
-
925
- # First frame conditioning: Image conditioning embedders
926
- self.image_embed_dim = image_embed_dim
927
- if image_embed_dim is not None:
928
- # Project image embeddings from vision encoder to transformer dim
929
- self.image_embedder = MotifVideoImageProjection(in_features=image_embed_dim, hidden_size=inner_dim)
930
-
931
- self.time_text_embed = MotifVideoConditionEmbedding(inner_dim, pooled_projection_dim)
932
-
933
- # 2. RoPE
934
- self.rope = MotifVideoRotaryPosEmbed(
935
- patch_size, patch_size_t, rope_axes_dim, rope_theta, base_latent_size=base_latent_size
936
- )
937
-
938
- # Cross-attention config
939
- self.enable_text_cross_attention_dual = enable_text_cross_attention_dual
940
- self.enable_text_cross_attention_single = enable_text_cross_attention_single
941
-
942
- # 3. Dual stream transformer blocks
943
- self.transformer_blocks = nn.ModuleList(
944
- [
945
- MotifVideoTransformerBlock(
946
- num_attention_heads,
947
- attention_head_dim,
948
- mlp_ratio=mlp_ratio,
949
- qk_norm=qk_norm,
950
- norm_type=norm_type,
951
- enable_text_cross_attention=enable_text_cross_attention_dual,
952
- )
953
- for _ in range(num_layers)
954
- ]
955
- )
956
-
957
- # 4. Single stream transformer blocks
958
- # Encoder blocks get cross-attention; decoder blocks do not (no text stream in decoder)
959
- num_encoder_single = num_single_layers - num_decoder_layers
960
- self.single_transformer_blocks = nn.ModuleList(
961
- [
962
- MotifVideoSingleTransformerBlock(
963
- num_attention_heads,
964
- attention_head_dim,
965
- mlp_ratio=mlp_ratio,
966
- qk_norm=qk_norm,
967
- norm_type=norm_type,
968
- enable_text_cross_attention=enable_text_cross_attention_single
969
- if i < num_encoder_single
970
- else False,
971
- )
972
- for i in range(num_single_layers)
973
- ]
974
- )
975
-
976
- # 5. Output projection
977
- self.norm_out = AdaLayerNormContinuous(
978
- inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type=norm_type
979
- )
980
- self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
981
-
982
- # Verify cross-attention config matches actual block state.
983
- # Catches silent misconfiguration (e.g. checkpoint config with renamed keys).
984
- for i, block in enumerate(self.transformer_blocks):
985
- if block.enable_text_cross_attention != enable_text_cross_attention_dual:
986
- raise ValueError(
987
- f"transformer_blocks[{i}].enable_text_cross_attention="
988
- f"{block.enable_text_cross_attention}, expected {enable_text_cross_attention_dual}. "
989
- f"Check checkpoint config.json key names match __init__ parameters."
990
- )
991
- num_encoder_single = num_single_layers - num_decoder_layers
992
- for i, block in enumerate(self.single_transformer_blocks):
993
- expected = enable_text_cross_attention_single if i < num_encoder_single else False
994
- if block.enable_text_cross_attention != expected:
995
- raise ValueError(
996
- f"single_transformer_blocks[{i}].enable_text_cross_attention="
997
- f"{block.enable_text_cross_attention}, expected {expected}. "
998
- f"Check checkpoint config.json key names match __init__ parameters."
999
- )
1000
-
1001
- self.gradient_checkpointing = False
1002
- self.num_decoder_layers = num_decoder_layers
1003
-
1004
- @property
1005
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
1006
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
1007
- r"""
1008
- Returns:
1009
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
1010
- indexed by its weight name.
1011
- """
1012
- # set recursively
1013
- processors = {}
1014
-
1015
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
1016
- if hasattr(module, "get_processor"):
1017
- processors[f"{name}.processor"] = module.get_processor()
1018
-
1019
- for sub_name, child in module.named_children():
1020
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1021
-
1022
- return processors
1023
-
1024
- for name, module in self.named_children():
1025
- fn_recursive_add_processors(name, module, processors)
1026
-
1027
- return processors
1028
-
1029
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
1030
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
1031
- r"""
1032
- Sets the attention processor to use to compute attention.
1033
-
1034
- Parameters:
1035
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1036
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
1037
- for **all** `Attention` layers.
1038
-
1039
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1040
- processor. This is strongly recommended when setting trainable attention processors.
1041
-
1042
- """
1043
- count = len(self.attn_processors.keys())
1044
-
1045
- if isinstance(processor, dict) and len(processor) != count:
1046
- raise ValueError(
1047
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1048
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1049
- )
1050
-
1051
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1052
- if hasattr(module, "set_processor"):
1053
- if not isinstance(processor, dict):
1054
- module.set_processor(processor)
1055
- else:
1056
- module.set_processor(processor.pop(f"{name}.processor"))
1057
-
1058
- for sub_name, child in module.named_children():
1059
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1060
-
1061
- for name, module in self.named_children():
1062
- fn_recursive_attn_processor(name, module, processor)
1063
-
1064
- def _maybe_gradient_checkpoint_block(self, block, *args):
1065
- if torch.is_grad_enabled() and self.gradient_checkpointing:
1066
- return self._gradient_checkpointing_func(block, *args)
1067
- return block(*args)
1068
-
1069
- def _get_unwrapped_blocks(self, blocks):
1070
- if hasattr(blocks, "_checkpoint_wrapped_module"):
1071
- return blocks._checkpoint_wrapped_module
1072
- elif hasattr(blocks, "module"):
1073
- return blocks.module
1074
- return blocks
1075
-
1076
- def _create_attention_mask(
1077
- self,
1078
- hidden_states: torch.Tensor,
1079
- encoder_attention_mask: torch.Tensor,
1080
- ) -> torch.Tensor:
1081
- """
1082
- Create attention mask of shape [B, 1, 1, N] where N = L + E,
1083
- based on latent tokens (always valid) and the encoder mask.
1084
-
1085
- Args:
1086
- hidden_states: [B, L, D]
1087
- encoder_attention_mask: [B, E] (required)
1088
-
1089
- Returns:
1090
- attention_mask: [B, 1, 1, N]
1091
- """
1092
- attention_mask = F.pad(
1093
- encoder_attention_mask.to(torch.bool),
1094
- (hidden_states.shape[1], 0),
1095
- value=True,
1096
- )
1097
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, L+E]
1098
- return attention_mask
1099
-
1100
- def forward(
1101
- self,
1102
- hidden_states: torch.Tensor,
1103
- timestep: torch.LongTensor,
1104
- encoder_hidden_states: torch.Tensor,
1105
- encoder_attention_mask: torch.Tensor | None = None,
1106
- pooled_projections: torch.Tensor | None = None,
1107
- image_embeds: torch.Tensor | None = None,
1108
- attention_kwargs: Optional[Dict[str, Any]] = None,
1109
- return_dict: bool = True,
1110
- tread_mixin: Optional[Any] = None,
1111
- tread_disabled: bool = False,
1112
- ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
1113
- """
1114
- Forward pass of the MotifVideoTransformer3DModel.
1115
-
1116
- Args:
1117
- hidden_states: Input latent tensor [B, C, F, H, W].
1118
- timestep: Diffusion timesteps [B].
1119
- encoder_hidden_states: Text conditioning [B, E, D].
1120
- encoder_attention_mask: Mask for text conditioning [B, E].
1121
- pooled_projections: Pooled text embeddings [B, D].
1122
- image_embeds: Optional image embeddings from vision encoder [B, N, D].
1123
- attention_kwargs: Additional arguments for attention processors.
1124
- return_dict: Whether to return a Transformer2DModelOutput.
1125
- tread_mixin: Optional TreadMixin instance for token reduction.
1126
- tread_disabled: When True, force tread_mixin to None (dense pass).
1127
- torch.compile specializes on this bool, producing separate graphs
1128
- for dense vs routed without attribute toggling.
1129
-
1130
- Returns:
1131
- Transformer2DModelOutput or tuple containing the predicted samples.
1132
- """
1133
- if tread_disabled:
1134
- tread_mixin = None
1135
- elif tread_mixin is None:
1136
- tread_mixin = getattr(self, "_inference_tread_mixin", None)
1137
-
1138
- if attention_kwargs is not None:
1139
- attention_kwargs = attention_kwargs.copy()
1140
- lora_scale = attention_kwargs.pop("scale", 1.0)
1141
- else:
1142
- lora_scale = 1.0
1143
-
1144
- if USE_PEFT_BACKEND:
1145
- # weight the lora layers by setting `lora_scale` for each PEFT layer
1146
- scale_lora_layers(self, lora_scale)
1147
- else:
1148
- if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
1149
- logger.warning(
1150
- "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
1151
- )
1152
-
1153
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
1154
- p, p_t = self.config.patch_size, self.config.patch_size_t
1155
- post_patch_num_frames = num_frames // p_t
1156
- post_patch_height = height // p
1157
- post_patch_width = width // p
1158
- first_frame_num_tokens = 1 * post_patch_height * post_patch_width
1159
- # 1. RoPE
1160
- image_rotary_emb = self.rope(hidden_states, timestep=timestep)
1161
- # 2. Conditional embeddings
1162
- temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)
1163
- hidden_states = self.x_embedder(hidden_states)
1164
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1165
-
1166
- # First frame conditioning: Image embeddings from vision encoder
1167
- if image_embeds is not None:
1168
- # image_embeds: [B, N, D_img] -> [B, N, D]
1169
- image_embeds = self.image_embedder(image_embeds)
1170
- encoder_hidden_states = torch.cat([image_embeds, encoder_hidden_states], dim=1)
1171
- # Extend attention mask for image tokens
1172
- if encoder_attention_mask is not None:
1173
- image_mask = torch.ones(
1174
- image_embeds.shape[0],
1175
- image_embeds.shape[1],
1176
- device=encoder_attention_mask.device,
1177
- dtype=encoder_attention_mask.dtype,
1178
- )
1179
- encoder_attention_mask = torch.cat([image_mask, encoder_attention_mask], dim=1)
1180
-
1181
- # image_embed_seq_len: used by cross-attention blocks to slice text from encoder_hidden_states
1182
- image_embed_seq_len = image_embeds.shape[1] if image_embeds is not None else 0
1183
-
1184
- decoder_hidden_states = hidden_states.clone()
1185
-
1186
- if encoder_attention_mask is not None:
1187
- attention_mask = self._create_attention_mask(
1188
- hidden_states=hidden_states,
1189
- encoder_attention_mask=encoder_attention_mask,
1190
- )
1191
- else:
1192
- attention_mask = None
1193
-
1194
- # TREAD state initialization: manage token reduction manually to support activation checkpointing
1195
- tread_active = False
1196
- current_route = None
1197
- ids_keep = None
1198
- x_full = None
1199
- orig_mask = attention_mask
1200
- orig_rope = image_rotary_emb
1201
- latent_len = hidden_states.shape[1]
1202
-
1203
- # 4. Dual stream transformer blocks (Encoder)
1204
- for i, block in enumerate(self.transformer_blocks):
1205
- # Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
1206
- if is_tread_start(tread_mixin, tread_active, i):
1207
- tread_active = True
1208
- current_route = tread_mixin._tread_route
1209
- # Reduce sequence length at the start of a TREAD route
1210
- ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
1211
- x_full = hidden_states.contiguous()
1212
- hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
1213
- attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
1214
- image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
1215
-
1216
- hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
1217
- block,
1218
- hidden_states,
1219
- encoder_hidden_states,
1220
- temb,
1221
- attention_mask,
1222
- image_rotary_emb,
1223
- token_replace_emb,
1224
- first_frame_num_tokens,
1225
- image_embed_seq_len,
1226
- encoder_attention_mask,
1227
- )
1228
-
1229
- if is_tread_end(tread_mixin, tread_active, i):
1230
- # Restore full sequence length at the end of a TREAD route
1231
- hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
1232
- tread_active = False
1233
- current_route = None
1234
- ids_keep = None
1235
- x_full = None
1236
- attention_mask = orig_mask
1237
- image_rotary_emb = orig_rope
1238
-
1239
- # We need to unwrap the blocks because CheckpointWrapper does not support len(),
1240
- # which is required for slicing the blocks into encoder and decoder parts.
1241
- single_transformer_blocks = self.single_transformer_blocks
1242
-
1243
- # 5. Single stream transformer blocks (Encoder)
1244
- num_dual = len(self.transformer_blocks)
1245
- for i, block in enumerate(
1246
- single_transformer_blocks[: len(single_transformer_blocks) - self.num_decoder_layers]
1247
- ):
1248
- # Drop tokens if (1) TREAD is enabled, (2) current block is within the TREAD route.
1249
- abs_i = num_dual + i
1250
- if is_tread_start(tread_mixin, tread_active, abs_i):
1251
- tread_active = True
1252
- current_route = tread_mixin._tread_route
1253
- # Reduce sequence length at the start of a TREAD route
1254
- ids_keep = tread_mixin.keep_indices(hidden_states, current_route["sel"]).to(hidden_states.device)
1255
- x_full = hidden_states.contiguous()
1256
- hidden_states = tread_mixin.gather_tokens(hidden_states, ids_keep)
1257
- attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
1258
- image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
1259
-
1260
- hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
1261
- block,
1262
- hidden_states,
1263
- encoder_hidden_states,
1264
- temb,
1265
- attention_mask,
1266
- image_rotary_emb,
1267
- token_replace_emb,
1268
- first_frame_num_tokens,
1269
- image_embed_seq_len,
1270
- encoder_attention_mask,
1271
- )
1272
-
1273
- if is_tread_end(tread_mixin, tread_active, abs_i):
1274
- # Restore full sequence length at the end of a TREAD route
1275
- hidden_states = tread_mixin.scatter_tokens(hidden_states, ids_keep, x_full)
1276
- tread_active = False
1277
- current_route = None
1278
- ids_keep = None
1279
- x_full = None
1280
- attention_mask = orig_mask
1281
- image_rotary_emb = orig_rope
1282
-
1283
- # 6. Single stream transformer blocks (Decoder)
1284
- if self.num_decoder_layers > 0:
1285
- encoder_hidden_states = hidden_states
1286
- attention_mask = None
1287
-
1288
- num_single = len(single_transformer_blocks)
1289
-
1290
- for i, block in enumerate(single_transformer_blocks[-self.num_decoder_layers :]):
1291
- abs_i = num_dual + (num_single - self.num_decoder_layers) + i
1292
- if is_tread_start(tread_mixin, tread_active, abs_i):
1293
- tread_active = True
1294
- current_route = tread_mixin._tread_route
1295
- # Reduce sequence length at the start of a TREAD route
1296
- ids_keep = tread_mixin.keep_indices(decoder_hidden_states, current_route["sel"]).to(
1297
- decoder_hidden_states.device
1298
- )
1299
- x_full = encoder_hidden_states.contiguous()
1300
- x_t_full = decoder_hidden_states.contiguous()
1301
- decoder_hidden_states = tread_mixin.gather_tokens(decoder_hidden_states, ids_keep)
1302
- encoder_hidden_states = tread_mixin.gather_tokens(encoder_hidden_states, ids_keep)
1303
- attention_mask = tread_mixin.adjust_mask(orig_mask, latent_len, ids_keep)
1304
- image_rotary_emb = tread_mixin.gather_rope(orig_rope, ids_keep)
1305
-
1306
- decoder_hidden_states, encoder_hidden_states = self._maybe_gradient_checkpoint_block(
1307
- block,
1308
- decoder_hidden_states,
1309
- encoder_hidden_states,
1310
- temb,
1311
- attention_mask,
1312
- image_rotary_emb,
1313
- token_replace_emb,
1314
- first_frame_num_tokens,
1315
- )
1316
-
1317
- if is_tread_end(tread_mixin, tread_active, abs_i):
1318
- # Restore full sequence length at the end of a TREAD route
1319
- decoder_hidden_states = tread_mixin.scatter_tokens(decoder_hidden_states, ids_keep, x_t_full)
1320
- encoder_hidden_states = tread_mixin.scatter_tokens(encoder_hidden_states, ids_keep, x_full)
1321
- tread_active = False
1322
- current_route = None
1323
- ids_keep = None
1324
- x_full = None
1325
- x_t_full = None
1326
- attention_mask = orig_mask
1327
- image_rotary_emb = orig_rope
1328
-
1329
- hidden_states = decoder_hidden_states
1330
-
1331
- # 7. Output projection
1332
- hidden_states = self.norm_out(hidden_states, temb)
1333
- hidden_states = self.proj_out(hidden_states)
1334
-
1335
- hidden_states = hidden_states.reshape(
1336
- batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
1337
- )
1338
- hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
1339
- hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
1340
-
1341
- if USE_PEFT_BACKEND:
1342
- # remove `lora_scale` from each PEFT layer
1343
- unscale_lora_layers(self, lora_scale)
1344
-
1345
- if not return_dict:
1346
- return (hidden_states,)
1347
-
1348
- return Transformer2DModelOutput(
1349
- sample=hidden_states,
1350
- )