Fabrice-TIERCELIN committed on
Commit
df6f431
·
verified ·
1 Parent(s): b2fbc55

Upload 2 files

Browse files
ltx_video/pipelines/crf_compressor.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import av
2
+ import torch
3
+ import io
4
+ import numpy as np
5
+
6
+
7
def _encode_single_frame(output_file, image_array: np.ndarray, crf):
    """Encode one RGB frame into `output_file` as a single-frame H.264/MP4 stream.

    Args:
        output_file: Path or file-like object opened by PyAV for writing.
        image_array: H x W x 3 uint8 RGB frame (even dimensions required by yuv420p).
        crf: x264 constant-rate-factor controlling compression strength.
    """
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream(
            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]
        # x264 consumes planar YUV; convert from the packed RGB input.
        frame = av.VideoFrame.from_ndarray(image_array, format="rgb24")
        frame = frame.reformat(format="yuv420p")
        container.mux(stream.encode(frame))
        # A final encode() with no frame flushes any buffered packets.
        container.mux(stream.encode())
    finally:
        container.close()
22
+
23
+
24
def _decode_single_frame(video_file):
    """Decode the first video frame of `video_file` and return it as an RGB uint8 ndarray."""
    container = av.open(video_file)
    try:
        video_stream = next(s for s in container.streams if s.type == "video")
        first_frame = next(container.decode(video_stream))
    finally:
        container.close()
    # The decoded frame owns its buffer, so conversion is safe after close().
    return first_frame.to_ndarray(format="rgb24")
32
+
33
+
34
def compress(image: torch.Tensor, crf=29):
    """Round-trip an image through H.264 compression at the given CRF.

    Args:
        image: H x W x C tensor with values in [0, 1] — TODO confirm range with callers.
        crf: Compression strength; 0 means "no compression" and returns the input as-is.

    Returns:
        Tensor with the same dtype/device as `image`, holding the decoded frame.
        Note: height/width are truncated to even values (yuv420p requirement),
        so the output may be one pixel smaller per axis than the input.
    """
    if crf == 0:
        return image

    even_h = (image.shape[0] // 2) * 2
    even_w = (image.shape[1] // 2) * 2
    image_array = (image[:even_h, :even_w] * 255.0).byte().cpu().numpy()

    with io.BytesIO() as encoded:
        _encode_single_frame(encoded, image_array, crf)
        video_bytes = encoded.getvalue()

    with io.BytesIO(video_bytes) as decoded_file:
        decoded = _decode_single_frame(decoded_file)

    return torch.tensor(decoded, dtype=image.dtype, device=image.device) / 255.0
ltx_video/pipelines/pipeline_ltx_video.py ADDED
@@ -0,0 +1,1845 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
2
+ import copy
3
+ import inspect
4
+ import math
5
+ import re
6
+ from contextlib import nullcontext
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models import AutoencoderKL
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
15
+ from diffusers.schedulers import DPMSolverMultistepScheduler
16
+ from diffusers.utils import deprecate, logging
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from einops import rearrange
19
+ from transformers import (
20
+ T5EncoderModel,
21
+ T5Tokenizer,
22
+ AutoModelForCausalLM,
23
+ AutoProcessor,
24
+ AutoTokenizer,
25
+ )
26
+
27
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
28
+ CausalVideoAutoencoder,
29
+ )
30
+ from ltx_video.models.autoencoders.vae_encode import (
31
+ get_vae_size_scale_factor,
32
+ latent_to_pixel_coords,
33
+ vae_decode,
34
+ vae_encode,
35
+ )
36
+ from ltx_video.models.transformers.symmetric_patchifier import Patchifier
37
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
38
+ from ltx_video.schedulers.rf import TimestepShifter
39
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
40
+ from ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt
41
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
42
+ from ltx_video.models.autoencoders.vae_encode import (
43
+ un_normalize_latents,
44
+ normalize_latents,
45
+ )
46
+
47
+
48
+ try:
49
+ import torch_xla.distributed.spmd as xs
50
+ except ImportError:
51
+ xs = None
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+
56
# Aspect-ratio bins (keyed by "height/width" ratio as a string) mapping to the
# [height, width] the pipeline snaps to. The 1024 table targets ~1MP outputs,
# the 512 table is the same set of ratios at half resolution.
ASPECT_RATIO_1024_BIN = {
    "0.25": [512.0, 2048.0],
    "0.28": [512.0, 1856.0],
    "0.32": [576.0, 1792.0],
    "0.33": [576.0, 1728.0],
    "0.35": [576.0, 1664.0],
    "0.4": [640.0, 1600.0],
    "0.42": [640.0, 1536.0],
    "0.48": [704.0, 1472.0],
    "0.5": [704.0, 1408.0],
    "0.52": [704.0, 1344.0],
    "0.57": [768.0, 1344.0],
    "0.6": [768.0, 1280.0],
    "0.68": [832.0, 1216.0],
    "0.72": [832.0, 1152.0],
    "0.78": [896.0, 1152.0],
    "0.82": [896.0, 1088.0],
    "0.88": [960.0, 1088.0],
    "0.94": [960.0, 1024.0],
    "1.0": [1024.0, 1024.0],
    "1.07": [1024.0, 960.0],
    "1.13": [1088.0, 960.0],
    "1.21": [1088.0, 896.0],
    "1.29": [1152.0, 896.0],
    "1.38": [1152.0, 832.0],
    "1.46": [1216.0, 832.0],
    "1.67": [1280.0, 768.0],
    "1.75": [1344.0, 768.0],
    "2.0": [1408.0, 704.0],
    "2.09": [1472.0, 704.0],
    "2.4": [1536.0, 640.0],
    "2.5": [1600.0, 640.0],
    "3.0": [1728.0, 576.0],
    "4.0": [2048.0, 512.0],
}

ASPECT_RATIO_512_BIN = {
    "0.25": [256.0, 1024.0],
    "0.28": [256.0, 928.0],
    "0.32": [288.0, 896.0],
    "0.33": [288.0, 864.0],
    "0.35": [288.0, 832.0],
    "0.4": [320.0, 800.0],
    "0.42": [320.0, 768.0],
    "0.48": [352.0, 736.0],
    "0.5": [352.0, 704.0],
    "0.52": [352.0, 672.0],
    "0.57": [384.0, 672.0],
    "0.6": [384.0, 640.0],
    "0.68": [416.0, 608.0],
    "0.72": [416.0, 576.0],
    "0.78": [448.0, 576.0],
    "0.82": [448.0, 544.0],
    "0.88": [480.0, 544.0],
    "0.94": [480.0, 512.0],
    "1.0": [512.0, 512.0],
    "1.07": [512.0, 480.0],
    "1.13": [544.0, 480.0],
    "1.21": [544.0, 448.0],
    "1.29": [576.0, 448.0],
    "1.38": [576.0, 416.0],
    "1.46": [608.0, 416.0],
    "1.67": [640.0, 384.0],
    "1.75": [672.0, 384.0],
    "2.0": [704.0, 352.0],
    "2.09": [736.0, 352.0],
    "2.4": [768.0, 320.0],
    "2.5": [800.0, 320.0],
    "3.0": [864.0, 288.0],
    "4.0": [1024.0, 256.0],
}
127
+
128
+
129
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
130
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    skip_initial_inference_steps: int = 0,
    skip_final_inference_steps: int = 0,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.
        skip_initial_inference_steps (`int`, *optional*, defaults to 0):
            Number of timesteps dropped from the start of the schedule (e.g. to begin denoising from a
            partially-noised input).
        skip_final_inference_steps (`int`, *optional*, defaults to 0):
            Number of timesteps dropped from the end of the schedule. The sum of both skip counts must be
            strictly less than the number of inference steps.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        # Custom schedules are only supported by schedulers whose
        # `set_timesteps` accepts an explicit `timesteps` argument.
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

    if (
        skip_initial_inference_steps < 0
        or skip_final_inference_steps < 0
        or skip_initial_inference_steps + skip_final_inference_steps
        >= num_inference_steps
    ):
        raise ValueError(
            "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps"
        )

    # Trim the schedule and re-register the truncated list with the scheduler
    # so its internal state matches what the denoising loop will iterate over.
    timesteps = timesteps[
        skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps
    ]
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    num_inference_steps = len(timesteps)

    return timesteps, num_inference_steps
196
+
197
+
198
@dataclass
class ConditioningItem:
    """
    Defines a single frame-conditioning item - a single frame or a sequence of frames.

    Attributes:
        media_item (torch.Tensor): shape=(b, 3, f, h, w). The media item to condition on.
        media_frame_number (int): The start-frame number of the media item in the generated video.
        conditioning_strength (float): The strength of the conditioning (1.0 = full conditioning).
        media_x (Optional[int]): Optional left x coordinate of the media item in the generated frame.
        media_y (Optional[int]): Optional top y coordinate of the media item in the generated frame.
    """

    # Pixel-space media tensor, shape (b, 3, f, h, w).
    media_item: torch.Tensor
    # Start frame index within the generated video.
    media_frame_number: int
    # 1.0 = hard conditioning; lower values weaken the constraint.
    conditioning_strength: float
    # Optional placement offsets within the generated frame (None = full frame).
    media_x: Optional[int] = None
    media_y: Optional[int] = None
216
+
217
+
218
+ class LTXVideoPipeline(DiffusionPipeline):
219
+ r"""
220
+ Pipeline for text-to-image generation using LTX-Video.
221
+
222
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
223
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
224
+
225
+ Args:
226
+ vae ([`AutoencoderKL`]):
227
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
228
+ text_encoder ([`T5EncoderModel`]):
229
+ Frozen text-encoder. This uses
230
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
231
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
232
+ tokenizer (`T5Tokenizer`):
233
+ Tokenizer of class
234
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
235
+ transformer ([`Transformer2DModel`]):
236
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
237
+ scheduler ([`SchedulerMixin`]):
238
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
239
+ """
240
+
241
+ bad_punct_regex = re.compile(
242
+ r"["
243
+ + "#®•©™&@·º½¾¿¡§~"
244
+ + r"\)"
245
+ + r"\("
246
+ + r"\]"
247
+ + r"\["
248
+ + r"\}"
249
+ + r"\{"
250
+ + r"\|"
251
+ + "\\"
252
+ + r"\/"
253
+ + r"\*"
254
+ + r"]{1,}"
255
+ ) # noqa
256
+
257
+ _optional_components = [
258
+ "tokenizer",
259
+ "text_encoder",
260
+ "prompt_enhancer_image_caption_model",
261
+ "prompt_enhancer_image_caption_processor",
262
+ "prompt_enhancer_llm_model",
263
+ "prompt_enhancer_llm_tokenizer",
264
+ ]
265
+ model_cpu_offload_seq = "prompt_enhancer_image_caption_model->prompt_enhancer_llm_model->text_encoder->transformer->vae"
266
+
267
+ def __init__(
268
+ self,
269
+ tokenizer: T5Tokenizer,
270
+ text_encoder: T5EncoderModel,
271
+ vae: AutoencoderKL,
272
+ transformer: Transformer3DModel,
273
+ scheduler: DPMSolverMultistepScheduler,
274
+ patchifier: Patchifier,
275
+ prompt_enhancer_image_caption_model: AutoModelForCausalLM,
276
+ prompt_enhancer_image_caption_processor: AutoProcessor,
277
+ prompt_enhancer_llm_model: AutoModelForCausalLM,
278
+ prompt_enhancer_llm_tokenizer: AutoTokenizer,
279
+ allowed_inference_steps: Optional[List[float]] = None,
280
+ ):
281
+ super().__init__()
282
+
283
+ self.register_modules(
284
+ tokenizer=tokenizer,
285
+ text_encoder=text_encoder,
286
+ vae=vae,
287
+ transformer=transformer,
288
+ scheduler=scheduler,
289
+ patchifier=patchifier,
290
+ prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model,
291
+ prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor,
292
+ prompt_enhancer_llm_model=prompt_enhancer_llm_model,
293
+ prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer,
294
+ )
295
+
296
+ self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
297
+ self.vae
298
+ )
299
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
300
+
301
+ self.allowed_inference_steps = allowed_inference_steps
302
+
303
+ def mask_text_embeddings(self, emb, mask):
304
+ if emb.shape[0] == 1:
305
+ keep_index = mask.sum().item()
306
+ return emb[:, :, :keep_index, :], keep_index
307
+ else:
308
+ masked_feature = emb * mask[:, None, :, None]
309
+ return masked_feature, emb.shape[2]
310
+
311
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
312
+ def encode_prompt(
313
+ self,
314
+ prompt: Union[str, List[str]],
315
+ do_classifier_free_guidance: bool = True,
316
+ negative_prompt: str = "",
317
+ num_images_per_prompt: int = 1,
318
+ device: Optional[torch.device] = None,
319
+ prompt_embeds: Optional[torch.FloatTensor] = None,
320
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
321
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
322
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
323
+ text_encoder_max_tokens: int = 256,
324
+ **kwargs,
325
+ ):
326
+ r"""
327
+ Encodes the prompt into text encoder hidden states.
328
+
329
+ Args:
330
+ prompt (`str` or `List[str]`, *optional*):
331
+ prompt to be encoded
332
+ negative_prompt (`str` or `List[str]`, *optional*):
333
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
334
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
335
+ This should be "".
336
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
337
+ whether to use classifier free guidance or not
338
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
339
+ number of images that should be generated per prompt
340
+ device: (`torch.device`, *optional*):
341
+ torch device to place the resulting embeddings on
342
+ prompt_embeds (`torch.FloatTensor`, *optional*):
343
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
344
+ provided, text embeddings will be generated from `prompt` input argument.
345
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
346
+ Pre-generated negative text embeddings.
347
+ """
348
+
349
+ if "mask_feature" in kwargs:
350
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
351
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
352
+
353
+ if device is None:
354
+ device = self._execution_device
355
+
356
+ if prompt is not None and isinstance(prompt, str):
357
+ batch_size = 1
358
+ elif prompt is not None and isinstance(prompt, list):
359
+ batch_size = len(prompt)
360
+ else:
361
+ batch_size = prompt_embeds.shape[0]
362
+
363
+ # See Section 3.1. of the paper.
364
+ max_length = (
365
+ text_encoder_max_tokens # TPU supports only lengths multiple of 128
366
+ )
367
+ if prompt_embeds is None:
368
+ assert (
369
+ self.text_encoder is not None
370
+ ), "You should provide either prompt_embeds or self.text_encoder should not be None,"
371
+ text_enc_device = next(self.text_encoder.parameters()).device
372
+ prompt = self._text_preprocessing(prompt)
373
+ text_inputs = self.tokenizer(
374
+ prompt,
375
+ padding="max_length",
376
+ max_length=max_length,
377
+ truncation=True,
378
+ add_special_tokens=True,
379
+ return_tensors="pt",
380
+ )
381
+ text_input_ids = text_inputs.input_ids
382
+ untruncated_ids = self.tokenizer(
383
+ prompt, padding="longest", return_tensors="pt"
384
+ ).input_ids
385
+
386
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
387
+ -1
388
+ ] and not torch.equal(text_input_ids, untruncated_ids):
389
+ removed_text = self.tokenizer.batch_decode(
390
+ untruncated_ids[:, max_length - 1 : -1]
391
+ )
392
+ logger.warning(
393
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
394
+ f" {max_length} tokens: {removed_text}"
395
+ )
396
+
397
+ prompt_attention_mask = text_inputs.attention_mask
398
+ prompt_attention_mask = prompt_attention_mask.to(text_enc_device)
399
+ prompt_attention_mask = prompt_attention_mask.to(device)
400
+
401
+ prompt_embeds = self.text_encoder(
402
+ text_input_ids.to(text_enc_device), attention_mask=prompt_attention_mask
403
+ )
404
+ prompt_embeds = prompt_embeds[0]
405
+
406
+ if self.text_encoder is not None:
407
+ dtype = self.text_encoder.dtype
408
+ elif self.transformer is not None:
409
+ dtype = self.transformer.dtype
410
+ else:
411
+ dtype = None
412
+
413
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
414
+
415
+ bs_embed, seq_len, _ = prompt_embeds.shape
416
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
417
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
418
+ prompt_embeds = prompt_embeds.view(
419
+ bs_embed * num_images_per_prompt, seq_len, -1
420
+ )
421
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
422
+ prompt_attention_mask = prompt_attention_mask.view(
423
+ bs_embed * num_images_per_prompt, -1
424
+ )
425
+
426
+ # get unconditional embeddings for classifier free guidance
427
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
428
+ uncond_tokens = self._text_preprocessing(negative_prompt)
429
+ uncond_tokens = uncond_tokens * batch_size
430
+ max_length = prompt_embeds.shape[1]
431
+ uncond_input = self.tokenizer(
432
+ uncond_tokens,
433
+ padding="max_length",
434
+ max_length=max_length,
435
+ truncation=True,
436
+ return_attention_mask=True,
437
+ add_special_tokens=True,
438
+ return_tensors="pt",
439
+ )
440
+ negative_prompt_attention_mask = uncond_input.attention_mask
441
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(
442
+ text_enc_device
443
+ )
444
+
445
+ negative_prompt_embeds = self.text_encoder(
446
+ uncond_input.input_ids.to(text_enc_device),
447
+ attention_mask=negative_prompt_attention_mask,
448
+ )
449
+ negative_prompt_embeds = negative_prompt_embeds[0]
450
+
451
+ if do_classifier_free_guidance:
452
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
453
+ seq_len = negative_prompt_embeds.shape[1]
454
+
455
+ negative_prompt_embeds = negative_prompt_embeds.to(
456
+ dtype=dtype, device=device
457
+ )
458
+
459
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
460
+ 1, num_images_per_prompt, 1
461
+ )
462
+ negative_prompt_embeds = negative_prompt_embeds.view(
463
+ batch_size * num_images_per_prompt, seq_len, -1
464
+ )
465
+
466
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
467
+ 1, num_images_per_prompt
468
+ )
469
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
470
+ bs_embed * num_images_per_prompt, -1
471
+ )
472
+ else:
473
+ negative_prompt_embeds = None
474
+ negative_prompt_attention_mask = None
475
+
476
+ return (
477
+ prompt_embeds,
478
+ prompt_attention_mask,
479
+ negative_prompt_embeds,
480
+ negative_prompt_attention_mask,
481
+ )
482
+
483
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
484
+ def prepare_extra_step_kwargs(self, generator, eta):
485
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
486
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
487
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
488
+ # and should be between [0, 1]
489
+
490
+ accepts_eta = "eta" in set(
491
+ inspect.signature(self.scheduler.step).parameters.keys()
492
+ )
493
+ extra_step_kwargs = {}
494
+ if accepts_eta:
495
+ extra_step_kwargs["eta"] = eta
496
+
497
+ # check if the scheduler accepts generator
498
+ accepts_generator = "generator" in set(
499
+ inspect.signature(self.scheduler.step).parameters.keys()
500
+ )
501
+ if accepts_generator:
502
+ extra_step_kwargs["generator"] = generator
503
+ return extra_step_kwargs
504
+
505
+ def check_inputs(
506
+ self,
507
+ prompt,
508
+ height,
509
+ width,
510
+ negative_prompt,
511
+ prompt_embeds=None,
512
+ negative_prompt_embeds=None,
513
+ prompt_attention_mask=None,
514
+ negative_prompt_attention_mask=None,
515
+ enhance_prompt=False,
516
+ ):
517
+ if height % 8 != 0 or width % 8 != 0:
518
+ raise ValueError(
519
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
520
+ )
521
+
522
+ if prompt is not None and prompt_embeds is not None:
523
+ raise ValueError(
524
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
525
+ " only forward one of the two."
526
+ )
527
+ elif prompt is None and prompt_embeds is None:
528
+ raise ValueError(
529
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
530
+ )
531
+ elif prompt is not None and (
532
+ not isinstance(prompt, str) and not isinstance(prompt, list)
533
+ ):
534
+ raise ValueError(
535
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
536
+ )
537
+
538
+ if prompt is not None and negative_prompt_embeds is not None:
539
+ raise ValueError(
540
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
541
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
542
+ )
543
+
544
+ if negative_prompt is not None and negative_prompt_embeds is not None:
545
+ raise ValueError(
546
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
547
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
548
+ )
549
+
550
+ if prompt_embeds is not None and prompt_attention_mask is None:
551
+ raise ValueError(
552
+ "Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
553
+ )
554
+
555
+ if (
556
+ negative_prompt_embeds is not None
557
+ and negative_prompt_attention_mask is None
558
+ ):
559
+ raise ValueError(
560
+ "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
561
+ )
562
+
563
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
564
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
565
+ raise ValueError(
566
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
567
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
568
+ f" {negative_prompt_embeds.shape}."
569
+ )
570
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
571
+ raise ValueError(
572
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
573
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
574
+ f" {negative_prompt_attention_mask.shape}."
575
+ )
576
+
577
+ if enhance_prompt:
578
+ assert (
579
+ self.prompt_enhancer_image_caption_model is not None
580
+ ), "Image caption model must be initialized if enhance_prompt is True"
581
+ assert (
582
+ self.prompt_enhancer_image_caption_processor is not None
583
+ ), "Image caption processor must be initialized if enhance_prompt is True"
584
+ assert (
585
+ self.prompt_enhancer_llm_model is not None
586
+ ), "Text prompt enhancer model must be initialized if enhance_prompt is True"
587
+ assert (
588
+ self.prompt_enhancer_llm_tokenizer is not None
589
+ ), "Text prompt enhancer tokenizer must be initialized if enhance_prompt is True"
590
+
591
+ def _text_preprocessing(self, text):
592
+ if not isinstance(text, (tuple, list)):
593
+ text = [text]
594
+
595
+ def process(text: str):
596
+ text = text.strip()
597
+ return text
598
+
599
+ return [process(t) for t in text]
600
+
601
+ @staticmethod
602
+ def add_noise_to_image_conditioning_latents(
603
+ t: float,
604
+ init_latents: torch.Tensor,
605
+ latents: torch.Tensor,
606
+ noise_scale: float,
607
+ conditioning_mask: torch.Tensor,
608
+ generator,
609
+ eps=1e-6,
610
+ ):
611
+ """
612
+ Add timestep-dependent noise to the hard-conditioning latents.
613
+ This helps with motion continuity, especially when conditioned on a single frame.
614
+ """
615
+ noise = randn_tensor(
616
+ latents.shape,
617
+ generator=generator,
618
+ device=latents.device,
619
+ dtype=latents.dtype,
620
+ )
621
+ # Add noise only to hard-conditioning latents (conditioning_mask = 1.0)
622
+ need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1)
623
+ noised_latents = init_latents + noise_scale * noise * (t**2)
624
+ latents = torch.where(need_to_noise, noised_latents, latents)
625
+ return latents
626
+
627
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
628
+ def prepare_latents(
629
+ self,
630
+ latents: torch.Tensor | None,
631
+ media_items: torch.Tensor | None,
632
+ timestep: float,
633
+ latent_shape: torch.Size | Tuple[Any, ...],
634
+ dtype: torch.dtype,
635
+ device: torch.device,
636
+ generator: torch.Generator | List[torch.Generator],
637
+ vae_per_channel_normalize: bool = True,
638
+ ):
639
+ """
640
+ Prepare the initial latent tensor to be denoised.
641
+ The latents are either pure noise or a noised version of the encoded media items.
642
+ Args:
643
+ latents (`torch.FloatTensor` or `None`):
644
+ The latents to use (provided by the user) or `None` to create new latents.
645
+ media_items (`torch.FloatTensor` or `None`):
646
+ An image or video to be updated using img2img or vid2vid. The media item is encoded and noised.
647
+ timestep (`float`):
648
+ The timestep to noise the encoded media_items to.
649
+ latent_shape (`torch.Size`):
650
+ The target latent shape.
651
+ dtype (`torch.dtype`):
652
+ The target dtype.
653
+ device (`torch.device`):
654
+ The target device.
655
+ generator (`torch.Generator` or `List[torch.Generator]`):
656
+ Generator(s) to be used for the noising process.
657
+ vae_per_channel_normalize ('bool'):
658
+ When encoding the media_items, whether to normalize the latents per-channel.
659
+ Returns:
660
+ `torch.FloatTensor`: The latents to be used for the denoising process. This is a tensor of shape
661
+ (batch_size, num_channels, height, width).
662
+ """
663
+ if isinstance(generator, list) and len(generator) != latent_shape[0]:
664
+ raise ValueError(
665
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
666
+ f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators."
667
+ )
668
+
669
+ # Initialize the latents with the given latents or encoded media item, if provided
670
+ assert (
671
+ latents is None or media_items is None
672
+ ), "Cannot provide both latents and media_items. Please provide only one of the two."
673
+
674
+ assert (
675
+ latents is None and media_items is None or timestep < 1.0
676
+ ), "Input media_item or latents are provided, but they will be replaced with noise."
677
+
678
+ if media_items is not None:
679
+ latents = vae_encode(
680
+ media_items.to(dtype=self.vae.dtype, device=self.vae.device),
681
+ self.vae,
682
+ vae_per_channel_normalize=vae_per_channel_normalize,
683
+ )
684
+ if latents is not None:
685
+ assert (
686
+ latents.shape == latent_shape
687
+ ), f"Latents have to be of shape {latent_shape} but are {latents.shape}."
688
+ latents = latents.to(device=device, dtype=dtype)
689
+
690
+ # For backward compatibility, generate in the "patchified" shape and rearrange
691
+ b, c, f, h, w = latent_shape
692
+ noise = randn_tensor(
693
+ (b, f * h * w, c), generator=generator, device=device, dtype=dtype
694
+ )
695
+ noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w)
696
+
697
+ # scale the initial noise by the standard deviation required by the scheduler
698
+ noise = noise * self.scheduler.init_noise_sigma
699
+
700
+ if latents is None:
701
+ latents = noise
702
+ else:
703
+ # Noise the latents to the required (first) timestep
704
+ latents = timestep * noise + (1 - timestep) * latents
705
+
706
+ return latents
707
+
708
+ @staticmethod
709
+ def classify_height_width_bin(
710
+ height: int, width: int, ratios: dict
711
+ ) -> Tuple[int, int]:
712
+ """Returns binned height and width."""
713
+ ar = float(height / width)
714
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
715
+ default_hw = ratios[closest_ratio]
716
+ return int(default_hw[0]), int(default_hw[1])
717
+
718
+ @staticmethod
719
+ def resize_and_crop_tensor(
720
+ samples: torch.Tensor, new_width: int, new_height: int
721
+ ) -> torch.Tensor:
722
+ n_frames, orig_height, orig_width = samples.shape[-3:]
723
+
724
+ # Check if resizing is needed
725
+ if orig_height != new_height or orig_width != new_width:
726
+ ratio = max(new_height / orig_height, new_width / orig_width)
727
+ resized_width = int(orig_width * ratio)
728
+ resized_height = int(orig_height * ratio)
729
+
730
+ # Resize
731
+ samples = LTXVideoPipeline.resize_tensor(
732
+ samples, resized_height, resized_width
733
+ )
734
+
735
+ # Center Crop
736
+ start_x = (resized_width - new_width) // 2
737
+ end_x = start_x + new_width
738
+ start_y = (resized_height - new_height) // 2
739
+ end_y = start_y + new_height
740
+ samples = samples[..., start_y:end_y, start_x:end_x]
741
+
742
+ return samples
743
+
744
+ @staticmethod
745
+ def resize_tensor(media_items, height, width):
746
+ n_frames = media_items.shape[2]
747
+ if media_items.shape[-2:] != (height, width):
748
+ media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
749
+ media_items = F.interpolate(
750
+ media_items,
751
+ size=(height, width),
752
+ mode="bilinear",
753
+ align_corners=False,
754
+ )
755
+ media_items = rearrange(media_items, "(b n) c h w -> b c n h w", n=n_frames)
756
+ return media_items
757
+
758
+ @torch.no_grad()
759
+ def __call__(
760
+ self,
761
+ height: int,
762
+ width: int,
763
+ num_frames: int,
764
+ frame_rate: float,
765
+ prompt: Union[str, List[str]] = None,
766
+ negative_prompt: str = "",
767
+ num_inference_steps: int = 20,
768
+ skip_initial_inference_steps: int = 0,
769
+ skip_final_inference_steps: int = 0,
770
+ timesteps: List[int] = None,
771
+ guidance_scale: Union[float, List[float]] = 4.5,
772
+ cfg_star_rescale: bool = False,
773
+ skip_layer_strategy: Optional[SkipLayerStrategy] = None,
774
+ skip_block_list: Optional[Union[List[List[int]], List[int]]] = None,
775
+ stg_scale: Union[float, List[float]] = 1.0,
776
+ rescaling_scale: Union[float, List[float]] = 0.7,
777
+ guidance_timesteps: Optional[List[int]] = None,
778
+ num_images_per_prompt: Optional[int] = 1,
779
+ eta: float = 0.0,
780
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
781
+ latents: Optional[torch.FloatTensor] = None,
782
+ prompt_embeds: Optional[torch.FloatTensor] = None,
783
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
784
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
785
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
786
+ output_type: Optional[str] = "pil",
787
+ return_dict: bool = True,
788
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
789
+ conditioning_items: Optional[List[ConditioningItem]] = None,
790
+ decode_timestep: Union[List[float], float] = 0.0,
791
+ decode_noise_scale: Optional[List[float]] = None,
792
+ mixed_precision: bool = False,
793
+ offload_to_cpu: bool = False,
794
+ enhance_prompt: bool = False,
795
+ text_encoder_max_tokens: int = 256,
796
+ stochastic_sampling: bool = False,
797
+ media_items: Optional[torch.Tensor] = None,
798
+ **kwargs,
799
+ ) -> Union[ImagePipelineOutput, Tuple]:
800
+ """
801
+ Function invoked when calling the pipeline for generation.
802
+
803
+ Args:
804
+ prompt (`str` or `List[str]`, *optional*):
805
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
806
+ instead.
807
+ negative_prompt (`str` or `List[str]`, *optional*):
808
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
809
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
810
+ less than `1`).
811
+ num_inference_steps (`int`, *optional*, defaults to 100):
812
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
813
+ expense of slower inference. If `timesteps` is provided, this parameter is ignored.
814
+ skip_initial_inference_steps (`int`, *optional*, defaults to 0):
815
+ The number of initial timesteps to skip. After calculating the timesteps, this number of timesteps will
816
+ be removed from the beginning of the timesteps list. Meaning the highest-timesteps values will not run.
817
+ skip_final_inference_steps (`int`, *optional*, defaults to 0):
818
+ The number of final timesteps to skip. After calculating the timesteps, this number of timesteps will
819
+ be removed from the end of the timesteps list. Meaning the lowest-timesteps values will not run.
820
+ timesteps (`List[int]`, *optional*):
821
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
822
+ timesteps are used. Must be in descending order.
823
+ guidance_scale (`float`, *optional*, defaults to 4.5):
824
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
825
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
826
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
827
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
828
+ usually at the expense of lower image quality.
829
+ cfg_star_rescale (`bool`, *optional*, defaults to `False`):
830
+ If set to `True`, applies the CFG star rescale. Scales the negative prediction according to dot
831
+ product between positive and negative.
832
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
833
+ The number of images to generate per prompt.
834
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
835
+ The height in pixels of the generated image.
836
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
837
+ The width in pixels of the generated image.
838
+ eta (`float`, *optional*, defaults to 0.0):
839
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
840
+ [`schedulers.DDIMScheduler`], will be ignored for others.
841
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
842
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
843
+ to make generation deterministic.
844
+ latents (`torch.FloatTensor`, *optional*):
845
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
846
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
847
+ tensor will ge generated by sampling using the supplied random `generator`.
848
+ prompt_embeds (`torch.FloatTensor`, *optional*):
849
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
850
+ provided, text embeddings will be generated from `prompt` input argument.
851
+ prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
852
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
853
+ Pre-generated negative text embeddings. This negative prompt should be "". If not
854
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
855
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
856
+ Pre-generated attention mask for negative text embeddings.
857
+ output_type (`str`, *optional*, defaults to `"pil"`):
858
+ The output format of the generate image. Choose between
859
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
860
+ return_dict (`bool`, *optional*, defaults to `True`):
861
+ Whether to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
862
+ callback_on_step_end (`Callable`, *optional*):
863
+ A function that calls at the end of each denoising steps during the inference. The function is called
864
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
865
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
866
+ `callback_on_step_end_tensor_inputs`.
867
+ use_resolution_binning (`bool` defaults to `True`):
868
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
869
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
870
+ the requested resolution. Useful for generating non-square images.
871
+ enhance_prompt (`bool`, *optional*, defaults to `False`):
872
+ If set to `True`, the prompt is enhanced using a LLM model.
873
+ text_encoder_max_tokens (`int`, *optional*, defaults to `256`):
874
+ The maximum number of tokens to use for the text encoder.
875
+ stochastic_sampling (`bool`, *optional*, defaults to `False`):
876
+ If set to `True`, the sampling is stochastic. If set to `False`, the sampling is deterministic.
877
+ media_items ('torch.Tensor', *optional*):
878
+ The input media item used for image-to-image / video-to-video.
879
+ Examples:
880
+
881
+ Returns:
882
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
883
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
884
+ returned where the first element is a list with the generated images
885
+ """
886
+ if "mask_feature" in kwargs:
887
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
888
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
889
+
890
+ is_video = kwargs.get("is_video", False)
891
+ self.check_inputs(
892
+ prompt,
893
+ height,
894
+ width,
895
+ negative_prompt,
896
+ prompt_embeds,
897
+ negative_prompt_embeds,
898
+ prompt_attention_mask,
899
+ negative_prompt_attention_mask,
900
+ )
901
+
902
+ # 2. Default height and width to transformer
903
+ if prompt is not None and isinstance(prompt, str):
904
+ batch_size = 1
905
+ elif prompt is not None and isinstance(prompt, list):
906
+ batch_size = len(prompt)
907
+ else:
908
+ batch_size = prompt_embeds.shape[0]
909
+
910
+ device = self._execution_device
911
+
912
+ self.video_scale_factor = self.video_scale_factor if is_video else 1
913
+ vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", True)
914
+ image_cond_noise_scale = kwargs.get("image_cond_noise_scale", 0.0)
915
+
916
+ latent_height = height // self.vae_scale_factor
917
+ latent_width = width // self.vae_scale_factor
918
+ latent_num_frames = num_frames // self.video_scale_factor
919
+ if isinstance(self.vae, CausalVideoAutoencoder) and is_video:
920
+ latent_num_frames += 1
921
+ latent_shape = (
922
+ batch_size * num_images_per_prompt,
923
+ self.transformer.config.in_channels,
924
+ latent_num_frames,
925
+ latent_height,
926
+ latent_width,
927
+ )
928
+
929
+ # Prepare the list of denoising time-steps
930
+
931
+ retrieve_timesteps_kwargs = {}
932
+ if isinstance(self.scheduler, TimestepShifter):
933
+ retrieve_timesteps_kwargs["samples_shape"] = latent_shape
934
+
935
+ assert (
936
+ skip_initial_inference_steps == 0
937
+ or latents is not None
938
+ or media_items is not None
939
+ ), (
940
+ f"skip_initial_inference_steps ({skip_initial_inference_steps}) is used for image-to-image/video-to-video - "
941
+ "media_item or latents should be provided."
942
+ )
943
+
944
+ timesteps, num_inference_steps = retrieve_timesteps(
945
+ self.scheduler,
946
+ num_inference_steps,
947
+ device,
948
+ timesteps,
949
+ skip_initial_inference_steps=skip_initial_inference_steps,
950
+ skip_final_inference_steps=skip_final_inference_steps,
951
+ **retrieve_timesteps_kwargs,
952
+ )
953
+
954
+ if self.allowed_inference_steps is not None:
955
+ for timestep in [round(x, 4) for x in timesteps.tolist()]:
956
+ assert (
957
+ timestep in self.allowed_inference_steps
958
+ ), f"Invalid inference timestep {timestep}. Allowed timesteps are {self.allowed_inference_steps}."
959
+
960
+ if guidance_timesteps:
961
+ guidance_mapping = []
962
+ for timestep in timesteps:
963
+ indices = [
964
+ i for i, val in enumerate(guidance_timesteps) if val <= timestep
965
+ ]
966
+ # assert len(indices) > 0, f"No guidance timestep found for {timestep}"
967
+ guidance_mapping.append(
968
+ indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1)
969
+ )
970
+
971
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
972
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
973
+ # corresponds to doing no classifier free guidance.
974
+ if not isinstance(guidance_scale, List):
975
+ guidance_scale = [guidance_scale] * len(timesteps)
976
+ else:
977
+ guidance_scale = [
978
+ guidance_scale[guidance_mapping[i]] for i in range(len(timesteps))
979
+ ]
980
+
981
+ # For simplicity, we are using a constant num_conds for all timesteps, so we need to zero
982
+ # out cases where the guidance scale should not be applied.
983
+ guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale]
984
+
985
+ if not isinstance(stg_scale, List):
986
+ stg_scale = [stg_scale] * len(timesteps)
987
+ else:
988
+ stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(timesteps))]
989
+
990
+ if not isinstance(rescaling_scale, List):
991
+ rescaling_scale = [rescaling_scale] * len(timesteps)
992
+ else:
993
+ rescaling_scale = [
994
+ rescaling_scale[guidance_mapping[i]] for i in range(len(timesteps))
995
+ ]
996
+
997
+ do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale)
998
+ do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale)
999
+ do_rescaling = any(x != 1.0 for x in rescaling_scale)
1000
+
1001
+ num_conds = 1
1002
+ if do_classifier_free_guidance:
1003
+ num_conds += 1
1004
+ if do_spatio_temporal_guidance:
1005
+ num_conds += 1
1006
+
1007
+ # Normalize skip_block_list to always be None or a list of lists matching timesteps
1008
+ if skip_block_list is not None:
1009
+ # Convert single list to list of lists if needed
1010
+ if len(skip_block_list) == 0 or not isinstance(skip_block_list[0], list):
1011
+ skip_block_list = [skip_block_list] * len(timesteps)
1012
+ else:
1013
+ new_skip_block_list = []
1014
+ for i, timestep in enumerate(timesteps):
1015
+ new_skip_block_list.append(skip_block_list[guidance_mapping[i]])
1016
+ skip_block_list = new_skip_block_list
1017
+
1018
+ # Prepare skip layer masks
1019
+ skip_layer_masks: Optional[List[torch.Tensor]] = None
1020
+ if do_spatio_temporal_guidance:
1021
+ if skip_block_list is not None:
1022
+ skip_layer_masks = [
1023
+ self.transformer.create_skip_layer_mask(
1024
+ batch_size, num_conds, num_conds - 1, skip_blocks
1025
+ )
1026
+ for skip_blocks in skip_block_list
1027
+ ]
1028
+
1029
+ if enhance_prompt:
1030
+ self.prompt_enhancer_image_caption_model = (
1031
+ self.prompt_enhancer_image_caption_model.to(self._execution_device)
1032
+ )
1033
+ self.prompt_enhancer_llm_model = self.prompt_enhancer_llm_model.to(
1034
+ self._execution_device
1035
+ )
1036
+
1037
+ prompt = generate_cinematic_prompt(
1038
+ self.prompt_enhancer_image_caption_model,
1039
+ self.prompt_enhancer_image_caption_processor,
1040
+ self.prompt_enhancer_llm_model,
1041
+ self.prompt_enhancer_llm_tokenizer,
1042
+ prompt,
1043
+ conditioning_items,
1044
+ max_new_tokens=text_encoder_max_tokens,
1045
+ )
1046
+
1047
+ # 3. Encode input prompt
1048
+ if self.text_encoder is not None:
1049
+ self.text_encoder = self.text_encoder.to(self._execution_device)
1050
+
1051
+ (
1052
+ prompt_embeds,
1053
+ prompt_attention_mask,
1054
+ negative_prompt_embeds,
1055
+ negative_prompt_attention_mask,
1056
+ ) = self.encode_prompt(
1057
+ prompt,
1058
+ do_classifier_free_guidance,
1059
+ negative_prompt=negative_prompt,
1060
+ num_images_per_prompt=num_images_per_prompt,
1061
+ device=device,
1062
+ prompt_embeds=prompt_embeds,
1063
+ negative_prompt_embeds=negative_prompt_embeds,
1064
+ prompt_attention_mask=prompt_attention_mask,
1065
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
1066
+ text_encoder_max_tokens=text_encoder_max_tokens,
1067
+ )
1068
+
1069
+ if offload_to_cpu and self.text_encoder is not None:
1070
+ self.text_encoder = self.text_encoder.cpu()
1071
+
1072
+ self.transformer = self.transformer.to(self._execution_device)
1073
+
1074
+ prompt_embeds_batch = prompt_embeds
1075
+ prompt_attention_mask_batch = prompt_attention_mask
1076
+ if do_classifier_free_guidance:
1077
+ prompt_embeds_batch = torch.cat(
1078
+ [negative_prompt_embeds, prompt_embeds], dim=0
1079
+ )
1080
+ prompt_attention_mask_batch = torch.cat(
1081
+ [negative_prompt_attention_mask, prompt_attention_mask], dim=0
1082
+ )
1083
+ if do_spatio_temporal_guidance:
1084
+ prompt_embeds_batch = torch.cat([prompt_embeds_batch, prompt_embeds], dim=0)
1085
+ prompt_attention_mask_batch = torch.cat(
1086
+ [
1087
+ prompt_attention_mask_batch,
1088
+ prompt_attention_mask,
1089
+ ],
1090
+ dim=0,
1091
+ )
1092
+
1093
+ # 4. Prepare the initial latents using the provided media and conditioning items
1094
+
1095
+ # Prepare the initial latents tensor, shape = (b, c, f, h, w)
1096
+ latents = self.prepare_latents(
1097
+ latents=latents,
1098
+ media_items=media_items,
1099
+ timestep=timesteps[0],
1100
+ latent_shape=latent_shape,
1101
+ dtype=prompt_embeds_batch.dtype,
1102
+ device=device,
1103
+ generator=generator,
1104
+ vae_per_channel_normalize=vae_per_channel_normalize,
1105
+ )
1106
+
1107
+ # Update the latents with the conditioning items and patchify them into (b, n, c)
1108
+ latents, pixel_coords, conditioning_mask, num_cond_latents = (
1109
+ self.prepare_conditioning(
1110
+ conditioning_items=conditioning_items,
1111
+ init_latents=latents,
1112
+ num_frames=num_frames,
1113
+ height=height,
1114
+ width=width,
1115
+ vae_per_channel_normalize=vae_per_channel_normalize,
1116
+ generator=generator,
1117
+ )
1118
+ )
1119
+ init_latents = latents.clone() # Used for image_cond_noise_update
1120
+
1121
+ pixel_coords = torch.cat([pixel_coords] * num_conds)
1122
+ orig_conditioning_mask = conditioning_mask
1123
+ if conditioning_mask is not None and is_video:
1124
+ assert num_images_per_prompt == 1
1125
+ conditioning_mask = torch.cat([conditioning_mask] * num_conds)
1126
+ fractional_coords = pixel_coords.to(torch.float32)
1127
+ fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
1128
+
1129
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1130
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1131
+
1132
+ # 7. Denoising loop
1133
+ num_warmup_steps = max(
1134
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
1135
+ )
1136
+
1137
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1138
+ for i, t in enumerate(timesteps):
1139
+ if conditioning_mask is not None and image_cond_noise_scale > 0.0:
1140
+ latents = self.add_noise_to_image_conditioning_latents(
1141
+ t,
1142
+ init_latents,
1143
+ latents,
1144
+ image_cond_noise_scale,
1145
+ orig_conditioning_mask,
1146
+ generator,
1147
+ )
1148
+
1149
+ latent_model_input = (
1150
+ torch.cat([latents] * num_conds) if num_conds > 1 else latents
1151
+ )
1152
+ latent_model_input = self.scheduler.scale_model_input(
1153
+ latent_model_input, t
1154
+ )
1155
+
1156
+ current_timestep = t
1157
+ if not torch.is_tensor(current_timestep):
1158
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1159
+ # This would be a good case for the `match` statement (Python 3.10+)
1160
+ is_mps = latent_model_input.device.type == "mps"
1161
+ if isinstance(current_timestep, float):
1162
+ dtype = torch.float32 if is_mps else torch.float64
1163
+ else:
1164
+ dtype = torch.int32 if is_mps else torch.int64
1165
+ current_timestep = torch.tensor(
1166
+ [current_timestep],
1167
+ dtype=dtype,
1168
+ device=latent_model_input.device,
1169
+ )
1170
+ elif len(current_timestep.shape) == 0:
1171
+ current_timestep = current_timestep[None].to(
1172
+ latent_model_input.device
1173
+ )
1174
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1175
+ current_timestep = current_timestep.expand(
1176
+ latent_model_input.shape[0]
1177
+ ).unsqueeze(-1)
1178
+
1179
+ if conditioning_mask is not None:
1180
+ # Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask)
1181
+ # and will start to be denoised when the current timestep is lower than their conditioning timestep.
1182
+ current_timestep = torch.min(
1183
+ current_timestep, 1.0 - conditioning_mask
1184
+ )
1185
+
1186
+ # Choose the appropriate context manager based on `mixed_precision`
1187
+ if mixed_precision:
1188
+ context_manager = torch.autocast(device.type, dtype=torch.bfloat16)
1189
+ else:
1190
+ context_manager = nullcontext() # Dummy context manager
1191
+
1192
+ # predict noise model_output
1193
+ with context_manager:
1194
+ noise_pred = self.transformer(
1195
+ latent_model_input.to(self.transformer.dtype),
1196
+ indices_grid=fractional_coords,
1197
+ encoder_hidden_states=prompt_embeds_batch.to(
1198
+ self.transformer.dtype
1199
+ ),
1200
+ encoder_attention_mask=prompt_attention_mask_batch,
1201
+ timestep=current_timestep,
1202
+ skip_layer_mask=(
1203
+ skip_layer_masks[i]
1204
+ if skip_layer_masks is not None
1205
+ else None
1206
+ ),
1207
+ skip_layer_strategy=skip_layer_strategy,
1208
+ return_dict=False,
1209
+ )[0]
1210
+
1211
+ # perform guidance
1212
+ if do_spatio_temporal_guidance:
1213
+ noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(
1214
+ num_conds
1215
+ )[-2:]
1216
+ if do_classifier_free_guidance:
1217
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_conds)[:2]
1218
+
1219
+ if cfg_star_rescale:
1220
+ # Rescales the unconditional noise prediction using the projection of the conditional prediction onto it:
1221
+ # α = (⟨ε_text, ε_uncond⟩ / ||ε_uncond||²), then ε_uncond ← α * ε_uncond
1222
+ # where ε_text is the conditional noise prediction and ε_uncond is the unconditional one.
1223
+ positive_flat = noise_pred_text.view(batch_size, -1)
1224
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
1225
+ dot_product = torch.sum(
1226
+ positive_flat * negative_flat, dim=1, keepdim=True
1227
+ )
1228
+ squared_norm = (
1229
+ torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
1230
+ )
1231
+ alpha = dot_product / squared_norm
1232
+ noise_pred_uncond = alpha * noise_pred_uncond
1233
+
1234
+ noise_pred = noise_pred_uncond + guidance_scale[i] * (
1235
+ noise_pred_text - noise_pred_uncond
1236
+ )
1237
+ elif do_spatio_temporal_guidance:
1238
+ noise_pred = noise_pred_text
1239
+ if do_spatio_temporal_guidance:
1240
+ noise_pred = noise_pred + stg_scale[i] * (
1241
+ noise_pred_text - noise_pred_text_perturb
1242
+ )
1243
+ if do_rescaling and stg_scale[i] > 0.0:
1244
+ noise_pred_text_std = noise_pred_text.view(batch_size, -1).std(
1245
+ dim=1, keepdim=True
1246
+ )
1247
+ noise_pred_std = noise_pred.view(batch_size, -1).std(
1248
+ dim=1, keepdim=True
1249
+ )
1250
+
1251
+ factor = noise_pred_text_std / noise_pred_std
1252
+ factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i])
1253
+
1254
+ noise_pred = noise_pred * factor.view(batch_size, 1, 1)
1255
+
1256
+ current_timestep = current_timestep[:1]
1257
+ # learned sigma
1258
+ if (
1259
+ self.transformer.config.out_channels // 2
1260
+ == self.transformer.config.in_channels
1261
+ ):
1262
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
1263
+
1264
+ # compute previous image: x_t -> x_t-1
1265
+ latents = self.denoising_step(
1266
+ latents,
1267
+ noise_pred,
1268
+ current_timestep,
1269
+ orig_conditioning_mask,
1270
+ t,
1271
+ extra_step_kwargs,
1272
+ stochastic_sampling=stochastic_sampling,
1273
+ )
1274
+
1275
+ # call the callback, if provided
1276
+ if i == len(timesteps) - 1 or (
1277
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1278
+ ):
1279
+ progress_bar.update()
1280
+
1281
+ if callback_on_step_end is not None:
1282
+ callback_on_step_end(self, i, t, {})
1283
+
1284
+ if offload_to_cpu:
1285
+ self.transformer = self.transformer.cpu()
1286
+ if self._execution_device == "cuda":
1287
+ torch.cuda.empty_cache()
1288
+
1289
+ # Remove the added conditioning latents
1290
+ latents = latents[:, num_cond_latents:]
1291
+
1292
+ latents = self.patchifier.unpatchify(
1293
+ latents=latents,
1294
+ output_height=latent_height,
1295
+ output_width=latent_width,
1296
+ out_channels=self.transformer.in_channels
1297
+ // math.prod(self.patchifier.patch_size),
1298
+ )
1299
+ if output_type != "latent":
1300
+ if self.vae.decoder.timestep_conditioning:
1301
+ noise = torch.randn_like(latents)
1302
+ if not isinstance(decode_timestep, list):
1303
+ decode_timestep = [decode_timestep] * latents.shape[0]
1304
+ if decode_noise_scale is None:
1305
+ decode_noise_scale = decode_timestep
1306
+ elif not isinstance(decode_noise_scale, list):
1307
+ decode_noise_scale = [decode_noise_scale] * latents.shape[0]
1308
+
1309
+ decode_timestep = torch.tensor(decode_timestep).to(latents.device)
1310
+ decode_noise_scale = torch.tensor(decode_noise_scale).to(
1311
+ latents.device
1312
+ )[:, None, None, None, None]
1313
+ latents = (
1314
+ latents * (1 - decode_noise_scale) + noise * decode_noise_scale
1315
+ )
1316
+ else:
1317
+ decode_timestep = None
1318
+ image = vae_decode(
1319
+ latents,
1320
+ self.vae,
1321
+ is_video,
1322
+ vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
1323
+ timestep=decode_timestep,
1324
+ )
1325
+
1326
+ image = self.image_processor.postprocess(image, output_type=output_type)
1327
+
1328
+ else:
1329
+ image = latents
1330
+
1331
+ # Offload all models
1332
+ self.maybe_free_model_hooks()
1333
+
1334
+ if not return_dict:
1335
+ return (image,)
1336
+
1337
+ return ImagePipelineOutput(images=image)
1338
+
1339
+ def denoising_step(
1340
+ self,
1341
+ latents: torch.Tensor,
1342
+ noise_pred: torch.Tensor,
1343
+ current_timestep: torch.Tensor,
1344
+ conditioning_mask: torch.Tensor,
1345
+ t: float,
1346
+ extra_step_kwargs,
1347
+ t_eps=1e-6,
1348
+ stochastic_sampling=False,
1349
+ ):
1350
+ """
1351
+ Perform the denoising step for the required tokens, based on the current timestep and
1352
+ conditioning mask:
1353
+ Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask)
1354
+ and will start to be denoised when the current timestep is equal or lower than their
1355
+ conditioning timestep.
1356
+ (hard-conditioning latents with conditioning_mask = 1.0 are never denoised)
1357
+ """
1358
+ # Denoise the latents using the scheduler
1359
+ denoised_latents = self.scheduler.step(
1360
+ noise_pred,
1361
+ t if current_timestep is None else current_timestep,
1362
+ latents,
1363
+ **extra_step_kwargs,
1364
+ return_dict=False,
1365
+ stochastic_sampling=stochastic_sampling,
1366
+ )[0]
1367
+
1368
+ if conditioning_mask is None:
1369
+ return denoised_latents
1370
+
1371
+ tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1)
1372
+ return torch.where(tokens_to_denoise_mask, denoised_latents, latents)
1373
+
1374
+ def prepare_conditioning(
1375
+ self,
1376
+ conditioning_items: Optional[List[ConditioningItem]],
1377
+ init_latents: torch.Tensor,
1378
+ num_frames: int,
1379
+ height: int,
1380
+ width: int,
1381
+ vae_per_channel_normalize: bool = False,
1382
+ generator=None,
1383
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
1384
+ """
1385
+ Prepare conditioning tokens based on the provided conditioning items.
1386
+
1387
+ This method encodes provided conditioning items (video frames or single frames) into latents
1388
+ and integrates them with the initial latent tensor. It also calculates corresponding pixel
1389
+ coordinates, a mask indicating the influence of conditioning latents, and the total number of
1390
+ conditioning latents.
1391
+
1392
+ Args:
1393
+ conditioning_items (Optional[List[ConditioningItem]]): A list of ConditioningItem objects.
1394
+ init_latents (torch.Tensor): The initial latent tensor of shape (b, c, f_l, h_l, w_l), where
1395
+ `f_l` is the number of latent frames, and `h_l` and `w_l` are latent spatial dimensions.
1396
+ num_frames, height, width: The dimensions of the generated video.
1397
+ vae_per_channel_normalize (bool, optional): Whether to normalize channels during VAE encoding.
1398
+ Defaults to `False`.
1399
+ generator: The random generator
1400
+
1401
+ Returns:
1402
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
1403
+ - `init_latents` (torch.Tensor): The updated latent tensor including conditioning latents,
1404
+ patchified into (b, n, c) shape.
1405
+ - `init_pixel_coords` (torch.Tensor): The pixel coordinates corresponding to the updated
1406
+ latent tensor.
1407
+ - `conditioning_mask` (torch.Tensor): A mask indicating the conditioning-strength of each
1408
+ latent token.
1409
+ - `num_cond_latents` (int): The total number of latent tokens added from conditioning items.
1410
+
1411
+ Raises:
1412
+ AssertionError: If input shapes, dimensions, or conditions for applying conditioning are invalid.
1413
+ """
1414
+ assert isinstance(self.vae, CausalVideoAutoencoder)
1415
+
1416
+ if conditioning_items:
1417
+ batch_size, _, num_latent_frames = init_latents.shape[:3]
1418
+
1419
+ init_conditioning_mask = torch.zeros(
1420
+ init_latents[:, 0, :, :, :].shape,
1421
+ dtype=torch.float32,
1422
+ device=init_latents.device,
1423
+ )
1424
+
1425
+ extra_conditioning_latents = []
1426
+ extra_conditioning_pixel_coords = []
1427
+ extra_conditioning_mask = []
1428
+ extra_conditioning_num_latents = 0 # Number of extra conditioning latents added (should be removed before decoding)
1429
+
1430
+ # Process each conditioning item
1431
+ for conditioning_item in conditioning_items:
1432
+ conditioning_item = self._resize_conditioning_item(
1433
+ conditioning_item, height, width
1434
+ )
1435
+ media_item = conditioning_item.media_item
1436
+ media_frame_number = conditioning_item.media_frame_number
1437
+ strength = conditioning_item.conditioning_strength
1438
+ assert media_item.ndim == 5 # (b, c, f, h, w)
1439
+ b, c, n_frames, h, w = media_item.shape
1440
+ assert (
1441
+ height == h and width == w
1442
+ ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
1443
+ assert n_frames % 8 == 1
1444
+ assert (
1445
+ media_frame_number >= 0
1446
+ and media_frame_number + n_frames <= num_frames
1447
+ )
1448
+
1449
+ # Encode the provided conditioning media item
1450
+ media_item_latents = vae_encode(
1451
+ media_item.to(dtype=self.vae.dtype, device=self.vae.device),
1452
+ self.vae,
1453
+ vae_per_channel_normalize=vae_per_channel_normalize,
1454
+ ).to(dtype=init_latents.dtype)
1455
+
1456
+ # Handle the different conditioning cases
1457
+ if media_frame_number == 0:
1458
+ # Get the target spatial position of the latent conditioning item
1459
+ media_item_latents, l_x, l_y = self._get_latent_spatial_position(
1460
+ media_item_latents,
1461
+ conditioning_item,
1462
+ height,
1463
+ width,
1464
+ strip_latent_border=True,
1465
+ )
1466
+ b, c_l, f_l, h_l, w_l = media_item_latents.shape
1467
+
1468
+ # First frame or sequence - just update the initial noise latents and the mask
1469
+ init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = (
1470
+ torch.lerp(
1471
+ init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l],
1472
+ media_item_latents,
1473
+ strength,
1474
+ )
1475
+ )
1476
+ init_conditioning_mask[
1477
+ :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l
1478
+ ] = strength
1479
+ else:
1480
+ # Non-first frame or sequence
1481
+ if n_frames > 1:
1482
+ # Handle non-first sequence.
1483
+ # Encoded latents are either fully consumed, or the prefix is handled separately below.
1484
+ (
1485
+ init_latents,
1486
+ init_conditioning_mask,
1487
+ media_item_latents,
1488
+ ) = self._handle_non_first_conditioning_sequence(
1489
+ init_latents,
1490
+ init_conditioning_mask,
1491
+ media_item_latents,
1492
+ media_frame_number,
1493
+ strength,
1494
+ )
1495
+
1496
+ # Single frame or sequence-prefix latents
1497
+ if media_item_latents is not None:
1498
+ noise = randn_tensor(
1499
+ media_item_latents.shape,
1500
+ generator=generator,
1501
+ device=media_item_latents.device,
1502
+ dtype=media_item_latents.dtype,
1503
+ )
1504
+
1505
+ media_item_latents = torch.lerp(
1506
+ noise, media_item_latents, strength
1507
+ )
1508
+
1509
+ # Patchify the extra conditioning latents and calculate their pixel coordinates
1510
+ media_item_latents, latent_coords = self.patchifier.patchify(
1511
+ latents=media_item_latents
1512
+ )
1513
+ pixel_coords = latent_to_pixel_coords(
1514
+ latent_coords,
1515
+ self.vae,
1516
+ causal_fix=self.transformer.config.causal_temporal_positioning,
1517
+ )
1518
+
1519
+ # Update the frame numbers to match the target frame number
1520
+ pixel_coords[:, 0] += media_frame_number
1521
+ extra_conditioning_num_latents += media_item_latents.shape[1]
1522
+
1523
+ conditioning_mask = torch.full(
1524
+ media_item_latents.shape[:2],
1525
+ strength,
1526
+ dtype=torch.float32,
1527
+ device=init_latents.device,
1528
+ )
1529
+
1530
+ extra_conditioning_latents.append(media_item_latents)
1531
+ extra_conditioning_pixel_coords.append(pixel_coords)
1532
+ extra_conditioning_mask.append(conditioning_mask)
1533
+
1534
+ # Patchify the updated latents and calculate their pixel coordinates
1535
+ init_latents, init_latent_coords = self.patchifier.patchify(
1536
+ latents=init_latents
1537
+ )
1538
+ init_pixel_coords = latent_to_pixel_coords(
1539
+ init_latent_coords,
1540
+ self.vae,
1541
+ causal_fix=self.transformer.config.causal_temporal_positioning,
1542
+ )
1543
+
1544
+ if not conditioning_items:
1545
+ return init_latents, init_pixel_coords, None, 0
1546
+
1547
+ init_conditioning_mask, _ = self.patchifier.patchify(
1548
+ latents=init_conditioning_mask.unsqueeze(1)
1549
+ )
1550
+ init_conditioning_mask = init_conditioning_mask.squeeze(-1)
1551
+
1552
+ if extra_conditioning_latents:
1553
+ # Stack the extra conditioning latents, pixel coordinates and mask
1554
+ init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
1555
+ init_pixel_coords = torch.cat(
1556
+ [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2
1557
+ )
1558
+ init_conditioning_mask = torch.cat(
1559
+ [*extra_conditioning_mask, init_conditioning_mask], dim=1
1560
+ )
1561
+
1562
+ if self.transformer.use_tpu_flash_attention:
1563
+ # When flash attention is used, keep the original number of tokens by removing
1564
+ # tokens from the end.
1565
+ init_latents = init_latents[:, :-extra_conditioning_num_latents]
1566
+ init_pixel_coords = init_pixel_coords[
1567
+ :, :, :-extra_conditioning_num_latents
1568
+ ]
1569
+ init_conditioning_mask = init_conditioning_mask[
1570
+ :, :-extra_conditioning_num_latents
1571
+ ]
1572
+
1573
+ return (
1574
+ init_latents,
1575
+ init_pixel_coords,
1576
+ init_conditioning_mask,
1577
+ extra_conditioning_num_latents,
1578
+ )
1579
+
1580
+ @staticmethod
1581
+ def _resize_conditioning_item(
1582
+ conditioning_item: ConditioningItem,
1583
+ height: int,
1584
+ width: int,
1585
+ ):
1586
+ if conditioning_item.media_x or conditioning_item.media_y:
1587
+ raise ValueError(
1588
+ "Provide media_item in the target size for spatial conditioning."
1589
+ )
1590
+ new_conditioning_item = copy.copy(conditioning_item)
1591
+ new_conditioning_item.media_item = LTXVideoPipeline.resize_tensor(
1592
+ conditioning_item.media_item, height, width
1593
+ )
1594
+ return new_conditioning_item
1595
+
1596
+ def _get_latent_spatial_position(
1597
+ self,
1598
+ latents: torch.Tensor,
1599
+ conditioning_item: ConditioningItem,
1600
+ height: int,
1601
+ width: int,
1602
+ strip_latent_border,
1603
+ ):
1604
+ """
1605
+ Get the spatial position of the conditioning item in the latent space.
1606
+ If requested, strip the conditioning latent borders that do not align with target borders.
1607
+ (border latents look different then other latents and might confuse the model)
1608
+ """
1609
+ scale = self.vae_scale_factor
1610
+ h, w = conditioning_item.media_item.shape[-2:]
1611
+ assert (
1612
+ h <= height and w <= width
1613
+ ), f"Conditioning item size {h}x{w} is larger than target size {height}x{width}"
1614
+ assert h % scale == 0 and w % scale == 0
1615
+
1616
+ # Compute the start and end spatial positions of the media item
1617
+ x_start, y_start = conditioning_item.media_x, conditioning_item.media_y
1618
+ x_start = (width - w) // 2 if x_start is None else x_start
1619
+ y_start = (height - h) // 2 if y_start is None else y_start
1620
+ x_end, y_end = x_start + w, y_start + h
1621
+ assert (
1622
+ x_end <= width and y_end <= height
1623
+ ), f"Conditioning item {x_start}:{x_end}x{y_start}:{y_end} is out of bounds for target size {width}x{height}"
1624
+
1625
+ if strip_latent_border:
1626
+ # Strip one latent from left/right and/or top/bottom, update x, y accordingly
1627
+ if x_start > 0:
1628
+ x_start += scale
1629
+ latents = latents[:, :, :, :, 1:]
1630
+
1631
+ if y_start > 0:
1632
+ y_start += scale
1633
+ latents = latents[:, :, :, 1:, :]
1634
+
1635
+ if x_end < width:
1636
+ latents = latents[:, :, :, :, :-1]
1637
+
1638
+ if y_end < height:
1639
+ latents = latents[:, :, :, :-1, :]
1640
+
1641
+ return latents, x_start // scale, y_start // scale
1642
+
1643
+ @staticmethod
1644
+ def _handle_non_first_conditioning_sequence(
1645
+ init_latents: torch.Tensor,
1646
+ init_conditioning_mask: torch.Tensor,
1647
+ latents: torch.Tensor,
1648
+ media_frame_number: int,
1649
+ strength: float,
1650
+ num_prefix_latent_frames: int = 2,
1651
+ prefix_latents_mode: str = "concat",
1652
+ prefix_soft_conditioning_strength: float = 0.15,
1653
+ ):
1654
+ """
1655
+ Special handling for a conditioning sequence that does not start on the first frame.
1656
+ The special handling is required to allow a short encoded video to be used as middle
1657
+ (or last) sequence in a longer video.
1658
+ Args:
1659
+ init_latents (torch.Tensor): The initial noise latents to be updated.
1660
+ init_conditioning_mask (torch.Tensor): The initial conditioning mask to be updated.
1661
+ latents (torch.Tensor): The encoded conditioning item.
1662
+ media_frame_number (int): The target frame number of the first frame in the conditioning sequence.
1663
+ strength (float): The conditioning strength for the conditioning latents.
1664
+ num_prefix_latent_frames (int, optional): The length of the sequence prefix, to be handled
1665
+ separately. Defaults to 2.
1666
+ prefix_latents_mode (str, optional): Special treatment for prefix (boundary) latents.
1667
+ - "drop": Drop the prefix latents.
1668
+ - "soft": Use the prefix latents, but with soft-conditioning
1669
+ - "concat": Add the prefix latents as extra tokens (like single frames)
1670
+ prefix_soft_conditioning_strength (float, optional): The strength of the soft-conditioning for
1671
+ the prefix latents, relevant if `prefix_latents_mode` is "soft". Defaults to 0.1.
1672
+
1673
+ """
1674
+ f_l = latents.shape[2]
1675
+ f_l_p = num_prefix_latent_frames
1676
+ assert f_l >= f_l_p
1677
+ assert media_frame_number % 8 == 0
1678
+ if f_l > f_l_p:
1679
+ # Insert the conditioning latents **excluding the prefix** into the sequence
1680
+ f_l_start = media_frame_number // 8 + f_l_p
1681
+ f_l_end = f_l_start + f_l - f_l_p
1682
+ init_latents[:, :, f_l_start:f_l_end] = torch.lerp(
1683
+ init_latents[:, :, f_l_start:f_l_end],
1684
+ latents[:, :, f_l_p:],
1685
+ strength,
1686
+ )
1687
+ # Mark these latent frames as conditioning latents
1688
+ init_conditioning_mask[:, f_l_start:f_l_end] = strength
1689
+
1690
+ # Handle the prefix-latents
1691
+ if prefix_latents_mode == "soft":
1692
+ if f_l_p > 1:
1693
+ # Drop the first (single-frame) latent and soft-condition the remaining prefix
1694
+ f_l_start = media_frame_number // 8 + 1
1695
+ f_l_end = f_l_start + f_l_p - 1
1696
+ strength = min(prefix_soft_conditioning_strength, strength)
1697
+ init_latents[:, :, f_l_start:f_l_end] = torch.lerp(
1698
+ init_latents[:, :, f_l_start:f_l_end],
1699
+ latents[:, :, 1:f_l_p],
1700
+ strength,
1701
+ )
1702
+ # Mark these latent frames as conditioning latents
1703
+ init_conditioning_mask[:, f_l_start:f_l_end] = strength
1704
+ latents = None # No more latents to handle
1705
+ elif prefix_latents_mode == "drop":
1706
+ # Drop the prefix latents
1707
+ latents = None
1708
+ elif prefix_latents_mode == "concat":
1709
+ # Pass-on the prefix latents to be handled as extra conditioning frames
1710
+ latents = latents[:, :, :f_l_p]
1711
+ else:
1712
+ raise ValueError(f"Invalid prefix_latents_mode: {prefix_latents_mode}")
1713
+ return (
1714
+ init_latents,
1715
+ init_conditioning_mask,
1716
+ latents,
1717
+ )
1718
+
1719
+ def trim_conditioning_sequence(
1720
+ self, start_frame: int, sequence_num_frames: int, target_num_frames: int
1721
+ ):
1722
+ """
1723
+ Trim a conditioning sequence to the allowed number of frames.
1724
+
1725
+ Args:
1726
+ start_frame (int): The target frame number of the first frame in the sequence.
1727
+ sequence_num_frames (int): The number of frames in the sequence.
1728
+ target_num_frames (int): The target number of frames in the generated video.
1729
+
1730
+ Returns:
1731
+ int: updated sequence length
1732
+ """
1733
+ scale_factor = self.video_scale_factor
1734
+ num_frames = min(sequence_num_frames, target_num_frames - start_frame)
1735
+ # Trim down to a multiple of temporal_scale_factor frames plus 1
1736
+ num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
1737
+ return num_frames
1738
+
1739
+
1740
def adain_filter_latent(
    latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
):
    """
    Applies Adaptive Instance Normalization (AdaIN) to `latents`, matching each
    (sample, channel) slice to the mean/std of the corresponding slice in
    `reference_latents`, then blends the result back with the input.

    Args:
        latents (torch.Tensor): Input latents to normalize.
        reference_latents (torch.Tensor): Latents providing the target statistics.
        factor (float): Blending factor between original and transformed latents.
            Range: -10.0 to 10.0, Default: 1.0

    Returns:
        torch.Tensor: The transformed latent tensor.
    """
    transformed = latents.clone()

    batch_size, num_channels = latents.shape[0], latents.shape[1]
    for sample_idx in range(batch_size):
        for channel_idx in range(num_channels):
            # Statistics over all remaining dims of the slice.
            ref_std, ref_mean = torch.std_mean(
                reference_latents[sample_idx, channel_idx], dim=None
            )
            src_std, src_mean = torch.std_mean(
                transformed[sample_idx, channel_idx], dim=None
            )
            normalized = (transformed[sample_idx, channel_idx] - src_mean) / src_std
            transformed[sample_idx, channel_idx] = normalized * ref_std + ref_mean

    # Blend between the untouched input and the fully transformed version.
    return torch.lerp(latents, transformed, factor)
1769
+
1770
+
1771
class LTXMultiScalePipeline:
    """
    Two-pass, multi-scale wrapper around an `LTXVideoPipeline`.

    Pass 1 runs the wrapped pipeline at a downscaled resolution and requests
    latent output; the latents are then spatially upsampled with a
    `LatentUpsampler`, AdaIN-matched to the first-pass latents, and handed to
    pass 2, which reruns the pipeline at twice the downscaled resolution. If
    the caller asked for pixel output, the final frames are bilinearly resized
    to the originally requested resolution.
    """

    def _upsample_latents(
        self, latest_upsampler: LatentUpsampler, latents: torch.Tensor
    ):
        # Spatially upsample `latents` with the given upsampler module.
        # NOTE(review): the parameter name looks like a typo for
        # `latent_upsampler` -- confirm before renaming (it is part of the
        # method's keyword interface).
        assert latents.device == latest_upsampler.device

        # The upsampler operates on raw latents, so undo the per-channel VAE
        # normalization first and re-apply it afterwards.
        latents = un_normalize_latents(
            latents, self.vae, vae_per_channel_normalize=True
        )
        upsampled_latents = latest_upsampler(latents)
        upsampled_latents = normalize_latents(
            upsampled_latents, self.vae, vae_per_channel_normalize=True
        )
        return upsampled_latents

    def __init__(
        self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler
    ):
        # Keep a direct reference to the pipeline's VAE for the latent
        # (un-)normalization performed in `_upsample_latents`.
        self.video_pipeline = video_pipeline
        self.vae = video_pipeline.vae
        self.latent_upsampler = latent_upsampler

    def __call__(
        self,
        downscale_factor: float,
        first_pass: dict,
        second_pass: dict,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """
        Run the two-pass generation.

        Args:
            downscale_factor: Spatial factor applied to the requested
                width/height for the first pass.
            first_pass: Keyword overrides merged into the first-pass call.
            second_pass: Keyword overrides merged into the second-pass call.
            *args / **kwargs: Forwarded to the wrapped `LTXVideoPipeline`;
                `kwargs` must include `output_type`, `width` and `height`.

        Returns:
            The second-pass pipeline result; `result.images` is resized to the
            originally requested resolution when `output_type` is not "latent".
        """
        # Remember the caller's request; `kwargs` is mutated for the first pass.
        original_kwargs = kwargs.copy()
        original_output_type = kwargs["output_type"]
        original_width = kwargs["width"]
        original_height = kwargs["height"]

        # Downscale the target size, rounding down to a multiple of the VAE
        # spatial scale factor.
        x_width = int(kwargs["width"] * downscale_factor)
        downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor)
        x_height = int(kwargs["height"] * downscale_factor)
        downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor)

        # First pass: generate latents at the downscaled resolution.
        kwargs["output_type"] = "latent"
        kwargs["width"] = downscaled_width
        kwargs["height"] = downscaled_height
        kwargs.update(**first_pass)
        result = self.video_pipeline(*args, **kwargs)
        latents = result.images

        # Upsample the latents (the second pass assumes a 2x spatial factor)
        # and match their per-channel statistics to the first-pass latents.
        upsampled_latents = self._upsample_latents(self.latent_upsampler, latents)
        upsampled_latents = adain_filter_latent(
            latents=upsampled_latents, reference_latents=latents
        )

        # Second pass: refine the upsampled latents at 2x the downscaled size,
        # starting again from the caller's original arguments.
        kwargs = original_kwargs

        kwargs["latents"] = upsampled_latents
        kwargs["output_type"] = original_output_type
        kwargs["width"] = downscaled_width * 2
        kwargs["height"] = downscaled_height * 2
        kwargs.update(**second_pass)

        result = self.video_pipeline(*args, **kwargs)
        if original_output_type != "latent":
            # Resize decoded frames to the originally requested resolution
            # (frame-by-frame bilinear interpolation).
            num_frames = result.images.shape[2]
            videos = rearrange(result.images, "b c f h w -> (b f) c h w")

            videos = F.interpolate(
                videos,
                size=(original_height, original_width),
                mode="bilinear",
                align_corners=False,
            )
            videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames)
            result.images = videos

        return result