eeuuia committed on
Commit 5ddc813 · verified · 1 Parent(s): 61634bf

Upload inference.py

Files changed (1)
  1. api/ltx/inference.py +785 -0
api/ltx/inference.py ADDED
@@ -0,0 +1,785 @@
+ import argparse
+ import os
+ import random
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Optional, List, Union
+ 
+ import yaml
+ 
+ from huggingface_hub import logging
+ 
+ # Only one verbosity level can be active at a time; use the most verbose one.
+ logging.set_verbosity_debug()
+ 
+ 
+ import imageio
+ import json
+ import numpy as np
+ import torch
+ import cv2
+ from safetensors import safe_open
+ from PIL import Image
+ from transformers import (
+     T5EncoderModel,
+     T5Tokenizer,
+     AutoModelForCausalLM,
+     AutoProcessor,
+     AutoTokenizer,
+ )
+ from huggingface_hub import hf_hub_download
+ 
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
+     CausalVideoAutoencoder,
+ )
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
+ from ltx_video.pipelines.pipeline_ltx_video import (
+     ConditioningItem,
+     LTXVideoPipeline,
+     LTXMultiScalePipeline,
+ )
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
+ import ltx_video.pipelines.crf_compressor as crf_compressor
+ 
+ MAX_HEIGHT = 720
+ MAX_WIDTH = 1280
+ MAX_NUM_FRAMES = 257
+ 
+ logger = logging.get_logger("LTX-Video")
+ 
+ 
+ def get_total_gpu_memory():
+     # Total memory of GPU 0, in GiB; 0 if CUDA is unavailable.
+     if torch.cuda.is_available():
+         total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+         return total_memory
+     return 0
+ 
+ 
+ def get_device():
+     if torch.cuda.is_available():
+         return "cuda"
+     elif torch.backends.mps.is_available():
+         return "mps"
+     return "cpu"
+ 
+ 
+ def load_image_to_tensor_with_resize_and_crop(
+     image_input: Union[str, Image.Image],
+     target_height: int = 512,
+     target_width: int = 768,
+     just_crop: bool = False,
+ ) -> torch.Tensor:
+     """Load and process an image into a tensor.
+ 
+     Args:
+         image_input: Either a file path (str) or a PIL Image object.
+         target_height: Desired height of the output tensor.
+         target_width: Desired width of the output tensor.
+         just_crop: If True, only crop the image to the target size without resizing.
+     """
+     if isinstance(image_input, str):
+         image = Image.open(image_input).convert("RGB")
+     elif isinstance(image_input, Image.Image):
+         image = image_input
+     else:
+         raise ValueError("image_input must be either a file path or a PIL Image object")
+ 
+     # Center-crop to the target aspect ratio before resizing.
+     input_width, input_height = image.size
+     aspect_ratio_target = target_width / target_height
+     aspect_ratio_frame = input_width / input_height
+     if aspect_ratio_frame > aspect_ratio_target:
+         new_width = int(input_height * aspect_ratio_target)
+         new_height = input_height
+         x_start = (input_width - new_width) // 2
+         y_start = 0
+     else:
+         new_width = input_width
+         new_height = int(input_width / aspect_ratio_target)
+         x_start = 0
+         y_start = (input_height - new_height) // 2
+ 
+     image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+     if not just_crop:
+         image = image.resize((target_width, target_height))
+ 
+     # Light blur and CRF-style compression before normalization.
+     image = np.array(image)
+     image = cv2.GaussianBlur(image, (3, 3), 0)
+     frame_tensor = torch.from_numpy(image).float()
+     frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
+     frame_tensor = frame_tensor.permute(2, 0, 1)
+     frame_tensor = (frame_tensor / 127.5) - 1.0
+     # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+     return frame_tensor.unsqueeze(0).unsqueeze(2)
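+ # Example of expected output (a sketch): with target_height=512 and target_width=768,
+ # the returned tensor has shape (1, 3, 1, 512, 768), with values roughly in [-1, 1].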
+ 
+ 
+ def calculate_padding(
+     source_height: int, source_width: int, target_height: int, target_width: int
+ ) -> tuple[int, int, int, int]:
+     # Calculate total padding needed
+     pad_height = target_height - source_height
+     pad_width = target_width - source_width
+ 
+     # Calculate padding for each side
+     pad_top = pad_height // 2
+     pad_bottom = pad_height - pad_top  # Handles odd padding
+     pad_left = pad_width // 2
+     pad_right = pad_width - pad_left  # Handles odd padding
+ 
+     # Padding format is (left, right, top, bottom)
+     padding = (pad_left, pad_right, pad_top, pad_bottom)
+     return padding
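+ # Worked example: calculate_padding(700, 1210, 704, 1216) == (3, 3, 2, 2),
+ # i.e. the 6 extra columns split 3/3 and the 4 extra rows split 2/2.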
+ 
+ 
+ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
+     # Keep only letters and spaces, lowercased
+     clean_text = "".join(
+         char.lower() for char in text if char.isalpha() or char.isspace()
+     )
+ 
+     # Split into words
+     words = clean_text.split()
+ 
+     # Build the result, tracking the total number of letters kept
+     result = []
+     current_length = 0
+ 
+     for word in words:
+         new_length = current_length + len(word)
+         if new_length <= max_len:
+             result.append(word)
+             current_length += len(word)
+         else:
+             break
+ 
+     return "-".join(result)
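+ # e.g. convert_prompt_to_filename("A red fox runs!", max_len=20) -> "a-red-fox-runs"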
+ 
+ 
+ # Generate a unique output file name
+ def get_unique_filename(
+     base: str,
+     ext: str,
+     prompt: str,
+     seed: int,
+     resolution: tuple[int, int, int],
+     dir: Path,
+     endswith=None,
+     index_range=1000,
+ ) -> Path:
+     base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
+     for i in range(index_range):
+         filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
+         if not os.path.exists(filename):
+             return filename
+     raise FileExistsError(
+         f"Could not find a unique filename after {index_range} attempts."
+     )
+ 
+ 
+ def seed_everything(seed: int):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+     if torch.backends.mps.is_available():
+         torch.mps.manual_seed(seed)
+ 
+ 
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Load models from separate directories and run the pipeline."
+     )
+ 
+     # Directories
+     parser.add_argument(
+         "--output_path",
+         type=str,
+         default=None,
+         help="Path to the folder to save the output video; if None, saves to the outputs/ directory.",
+     )
+     parser.add_argument("--seed", type=int, default=171198)
+ 
+     # Pipeline parameters
+     parser.add_argument(
+         "--num_images_per_prompt",
+         type=int,
+         default=1,
+         help="Number of images per prompt",
+     )
+     parser.add_argument(
+         "--image_cond_noise_scale",
+         type=float,
+         default=0.15,
+         help="Amount of noise to add to the conditioned image",
+     )
+     parser.add_argument(
+         "--height",
+         type=int,
+         default=704,
+         help="Height of the output video frames. Optional if an input image is provided.",
+     )
+     parser.add_argument(
+         "--width",
+         type=int,
+         default=1216,
+         help="Width of the output video frames. If None, inferred from the input image.",
+     )
+     parser.add_argument(
+         "--num_frames",
+         type=int,
+         default=121,
+         help="Number of frames to generate in the output video",
+     )
+     parser.add_argument(
+         "--frame_rate", type=int, default=30, help="Frame rate for the output video"
+     )
+     parser.add_argument(
+         "--device",
+         default=None,
+         help="Device to run inference on. If not specified, automatically uses CUDA or MPS if available, else CPU.",
+     )
+     parser.add_argument(
+         "--pipeline_config",
+         type=str,
+         default="configs/ltxv-13b-0.9.7-dev.yaml",
+         help="Path to the pipeline config file, which contains the parameters for the pipeline",
+     )
+ 
+     # Prompts
+     parser.add_argument(
+         "--prompt",
+         type=str,
+         help="Text prompt to guide generation",
+     )
+     parser.add_argument(
+         "--negative_prompt",
+         type=str,
+         default="worst quality, inconsistent motion, blurry, jittery, distorted",
+         help="Negative prompt for undesired features",
+     )
+ 
+     parser.add_argument(
+         "--offload_to_cpu",
+         action="store_true",
+         help="Offload unnecessary computations to the CPU.",
+     )
+ 
+     # video-to-video arguments:
+     parser.add_argument(
+         "--input_media_path",
+         type=str,
+         default=None,
+         help="Path to the input video (or image) to be modified using the video-to-video pipeline",
+     )
+ 
+     # Conditioning arguments
+     parser.add_argument(
+         "--conditioning_media_paths",
+         type=str,
+         nargs="*",
+         help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.",
+     )
+     parser.add_argument(
+         "--conditioning_strengths",
+         type=float,
+         nargs="*",
+         help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.",
+     )
+     parser.add_argument(
+         "--conditioning_start_frames",
+         type=int,
+         nargs="*",
+         help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.",
+     )
+ 
+     args = parser.parse_args()
+     logger.warning(f"Running generation with arguments: {args}")
+     infer(**vars(args))
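+ 
+ # Example CLI invocation (a sketch; assumes the default config file is present and
+ # the checkpoint can be fetched from the Hugging Face Hub):
+ #   python inference.py \
+ #     --prompt "a sailboat drifting across a calm bay at sunset" \
+ #     --height 704 --width 1216 --num_frames 121 --frame_rate 30 \
+ #     --seed 171198 --pipeline_config configs/ltxv-13b-0.9.7-dev.yaml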
+ 
+ 
+ def create_ltx_video_pipeline(
+     ckpt_path: str,
+     precision: str,
+     text_encoder_model_name_or_path: str,
+     sampler: Optional[str] = None,
+     device: Optional[str] = None,
+     enhance_prompt: bool = False,
+     prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
+     prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
+ ) -> LTXVideoPipeline:
+     ckpt_path = Path(ckpt_path)
+     assert os.path.exists(
+         ckpt_path
+     ), f"Checkpoint path {ckpt_path} does not exist"
+ 
+     # Read the pipeline config embedded in the safetensors metadata
+     with safe_open(ckpt_path, framework="pt") as f:
+         metadata = f.metadata()
+         config_str = metadata.get("config")
+         configs = json.loads(config_str)
+         allowed_inference_steps = configs.get("allowed_inference_steps", None)
+ 
+     vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
+     transformer = Transformer3DModel.from_pretrained(ckpt_path)
+ 
+     # Use the constructor if a sampler is specified, otherwise use from_pretrained
+     if sampler == "from_checkpoint" or not sampler:
+         scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
+     else:
+         scheduler = RectifiedFlowScheduler(
+             sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
+         )
+ 
+     text_encoder = T5EncoderModel.from_pretrained(
+         text_encoder_model_name_or_path, subfolder="text_encoder"
+     )
+     patchifier = SymmetricPatchifier(patch_size=1)
+     tokenizer = T5Tokenizer.from_pretrained(
+         text_encoder_model_name_or_path, subfolder="tokenizer"
+     )
+ 
+     transformer = transformer.to(device)
+     vae = vae.to(device)
+     text_encoder = text_encoder.to(device)
+ 
+     if enhance_prompt:
+         prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
+             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+         )
+         prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
+             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+         )
+         prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
+             prompt_enhancer_llm_model_name_or_path,
+             torch_dtype=torch.bfloat16,
+         )
+         prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
+             prompt_enhancer_llm_model_name_or_path,
+         )
+     else:
+         prompt_enhancer_image_caption_model = None
+         prompt_enhancer_image_caption_processor = None
+         prompt_enhancer_llm_model = None
+         prompt_enhancer_llm_tokenizer = None
+ 
+     vae = vae.to(torch.bfloat16)
+     if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
+         transformer = transformer.to(torch.bfloat16)
+     text_encoder = text_encoder.to(torch.bfloat16)
+ 
+     # Use submodels for the pipeline
+     submodel_dict = {
+         "transformer": transformer,
+         "patchifier": patchifier,
+         "text_encoder": text_encoder,
+         "tokenizer": tokenizer,
+         "scheduler": scheduler,
+         "vae": vae,
+         "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
+         "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
+         "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
+         "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
+         "allowed_inference_steps": allowed_inference_steps,
+     }
+ 
+     pipeline = LTXVideoPipeline(**submodel_dict)
+     pipeline = pipeline.to(device)
+     return pipeline
+ 
+ 
+ def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
+     latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
+     latent_upsampler.to(device)
+     latent_upsampler.eval()
+     return latent_upsampler
+ 
+ 
+ def infer(
+     output_path: Optional[str],
+     seed: int,
+     pipeline_config: str,
+     image_cond_noise_scale: float,
+     height: Optional[int],
+     width: Optional[int],
+     num_frames: int,
+     frame_rate: int,
+     prompt: str,
+     negative_prompt: str,
+     offload_to_cpu: bool,
+     input_media_path: Optional[str] = None,
+     conditioning_media_paths: Optional[List[str]] = None,
+     conditioning_strengths: Optional[List[float]] = None,
+     conditioning_start_frames: Optional[List[int]] = None,
+     device: Optional[str] = None,
+     **kwargs,
+ ):
+     # Check that pipeline_config is a file, then load it as a dict
+     if not os.path.isfile(pipeline_config):
+         raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
+     with open(pipeline_config, "r") as f:
+         pipeline_config = yaml.safe_load(f)
+ 
+     models_dir = "MODEL_DIR"
+ 
+     ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
+     if not os.path.isfile(ltxv_model_name_or_path):
+         ltxv_model_path = hf_hub_download(
+             repo_id="Lightricks/LTX-Video",
+             filename=ltxv_model_name_or_path,
+             local_dir=models_dir,
+             repo_type="model",
+         )
+     else:
+         ltxv_model_path = ltxv_model_name_or_path
+ 
+     spatial_upscaler_model_name_or_path = pipeline_config.get(
+         "spatial_upscaler_model_path"
+     )
+     if spatial_upscaler_model_name_or_path and not os.path.isfile(
+         spatial_upscaler_model_name_or_path
+     ):
+         spatial_upscaler_model_path = hf_hub_download(
+             repo_id="Lightricks/LTX-Video",
+             filename=spatial_upscaler_model_name_or_path,
+             local_dir=models_dir,
+             repo_type="model",
+         )
+     else:
+         spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
+ 
+     # Backward compatibility for the deprecated input_image_path argument
+     if kwargs.get("input_image_path", None):
+         logger.warning(
+             "Please use conditioning_media_paths instead of input_image_path."
+         )
+         assert not conditioning_media_paths and not conditioning_start_frames
+         conditioning_media_paths = [kwargs["input_image_path"]]
+         conditioning_start_frames = [0]
+ 
+     # Validate conditioning arguments
+     if conditioning_media_paths:
+         # Use default strengths of 1.0
+         if not conditioning_strengths:
+             conditioning_strengths = [1.0] * len(conditioning_media_paths)
+         if not conditioning_start_frames:
+             raise ValueError(
+                 "If `conditioning_media_paths` is provided, "
+                 "`conditioning_start_frames` must also be provided"
+             )
+         if len(conditioning_media_paths) != len(conditioning_strengths) or len(
+             conditioning_media_paths
+         ) != len(conditioning_start_frames):
+             raise ValueError(
+                 "`conditioning_media_paths`, `conditioning_strengths`, "
+                 "and `conditioning_start_frames` must have the same length"
+             )
+         if any(s < 0 or s > 1 for s in conditioning_strengths):
+             raise ValueError("All conditioning strengths must be between 0 and 1")
+         if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
+             raise ValueError(
+                 f"All conditioning start frames must be between 0 and {num_frames - 1}"
+             )
+ 
+     seed_everything(seed)
+     if offload_to_cpu and not torch.cuda.is_available():
+         logger.warning(
+             "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
+         )
+         offload_to_cpu = False
+     else:
+         offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30
+ 
+     output_dir = (
+         Path(output_path)
+         if output_path
+         else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
+     )
+     output_dir.mkdir(parents=True, exist_ok=True)
+ 
+     # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
+     height_padded = ((height - 1) // 32 + 1) * 32
+     width_padded = ((width - 1) // 32 + 1) * 32
+     num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
+ 
+     padding = calculate_padding(height, width, height_padded, width_padded)
+ 
+     logger.warning(
+         f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
+     )
+ 
+     prompt_enhancement_words_threshold = pipeline_config[
+         "prompt_enhancement_words_threshold"
+     ]
+ 
+     prompt_word_count = len(prompt.split())
+     enhance_prompt = (
+         prompt_enhancement_words_threshold > 0
+         and prompt_word_count < prompt_enhancement_words_threshold
+     )
+ 
+     if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
+         logger.info(
+             f"Prompt has {prompt_word_count} words, which meets or exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
+         )
+ 
+     precision = pipeline_config["precision"]
+     text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
+     sampler = pipeline_config["sampler"]
+     prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
+         "prompt_enhancer_image_caption_model_name_or_path"
+     ]
+     prompt_enhancer_llm_model_name_or_path = pipeline_config[
+         "prompt_enhancer_llm_model_name_or_path"
+     ]
+ 
+     pipeline = create_ltx_video_pipeline(
+         ckpt_path=ltxv_model_path,
+         precision=precision,
+         text_encoder_model_name_or_path=text_encoder_model_name_or_path,
+         sampler=sampler,
+         device=device or get_device(),
+         enhance_prompt=enhance_prompt,
+         prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
+         prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
+     )
+ 
+     if pipeline_config.get("pipeline_type", None) == "multi-scale":
+         if not spatial_upscaler_model_path:
+             raise ValueError(
+                 "spatial upscaler model path is missing from the pipeline config file and is required for multi-scale rendering"
+             )
+         latent_upsampler = create_latent_upsampler(
+             spatial_upscaler_model_path, pipeline.device
+         )
+         pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
+ 
+     media_item = None
+     if input_media_path:
+         media_item = load_media_file(
+             media_path=input_media_path,
+             height=height,
+             width=width,
+             max_frames=num_frames_padded,
+             padding=padding,
+         )
+ 
+     conditioning_items = (
+         prepare_conditioning(
+             conditioning_media_paths=conditioning_media_paths,
+             conditioning_strengths=conditioning_strengths,
+             conditioning_start_frames=conditioning_start_frames,
+             height=height,
+             width=width,
+             num_frames=num_frames,
+             padding=padding,
+             pipeline=pipeline,
+         )
+         if conditioning_media_paths
+         else None
+     )
+ 
+     # Pop stg_mode so it is not passed through to the pipeline call below
+     stg_mode = pipeline_config.pop("stg_mode", "attention_values")
+     if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
+         skip_layer_strategy = SkipLayerStrategy.AttentionValues
+     elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
+         skip_layer_strategy = SkipLayerStrategy.AttentionSkip
+     elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
+         skip_layer_strategy = SkipLayerStrategy.Residual
+     elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
+         skip_layer_strategy = SkipLayerStrategy.TransformerBlock
+     else:
+         raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
+ 
+     # Prepare input for the pipeline
+     sample = {
+         "prompt": prompt,
+         "prompt_attention_mask": None,
+         "negative_prompt": negative_prompt,
+         "negative_prompt_attention_mask": None,
+     }
+ 
+     device = device or get_device()
+     generator = torch.Generator(device=device).manual_seed(seed)
+ 
+     images = pipeline(
+         **pipeline_config,
+         skip_layer_strategy=skip_layer_strategy,
+         generator=generator,
+         output_type="pt",
+         callback_on_step_end=None,
+         height=height_padded,
+         width=width_padded,
+         num_frames=num_frames_padded,
+         frame_rate=frame_rate,
+         **sample,
+         media_items=media_item,
+         conditioning_items=conditioning_items,
+         is_video=True,
+         vae_per_channel_normalize=True,
+         image_cond_noise_scale=image_cond_noise_scale,
+         mixed_precision=(precision == "mixed_precision"),
+         offload_to_cpu=offload_to_cpu,
+         device=device,
+         enhance_prompt=enhance_prompt,
+     ).images
+ 
+     # Crop the padded images to the desired resolution and number of frames
+     (pad_left, pad_right, pad_top, pad_bottom) = padding
+     pad_bottom = -pad_bottom
+     pad_right = -pad_right
+     # A crop of -0 would be empty, so use the full extent instead
+     if pad_bottom == 0:
+         pad_bottom = images.shape[3]
+     if pad_right == 0:
+         pad_right = images.shape[4]
+     images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]
+ 
+     for i in range(images.shape[0]):
+         # Gather from (B, C, F, H, W) to (C, F, H, W), then permute to (F, H, W, C)
+         video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
+         # Unnormalize images to the [0, 255] range
+         video_np = (video_np * 255).astype(np.uint8)
+         fps = frame_rate
+         height, width = video_np.shape[1:3]
+         # In case a single image is generated
+         if video_np.shape[0] == 1:
+             output_filename = get_unique_filename(
+                 f"image_output_{i}",
+                 ".png",
+                 prompt=prompt,
+                 seed=seed,
+                 resolution=(height, width, num_frames),
+                 dir=output_dir,
+             )
+             imageio.imwrite(output_filename, video_np[0])
+         else:
+             output_filename = get_unique_filename(
+                 f"video_output_{i}",
+                 ".mp4",
+                 prompt=prompt,
+                 seed=seed,
+                 resolution=(height, width, num_frames),
+                 dir=output_dir,
+             )
+ 
+             # Write video
+             with imageio.get_writer(output_filename, fps=fps) as video:
+                 for frame in video_np:
+                     video.append_data(frame)
+ 
+         logger.warning(f"Output saved to {output_filename}")
+ 
+ 
+ def prepare_conditioning(
+     conditioning_media_paths: List[str],
+     conditioning_strengths: List[float],
+     conditioning_start_frames: List[int],
+     height: int,
+     width: int,
+     num_frames: int,
+     padding: tuple[int, int, int, int],
+     pipeline: LTXVideoPipeline,
+ ) -> Optional[List[ConditioningItem]]:
+     """Prepare conditioning items based on input media paths and their parameters.
+ 
+     Args:
+         conditioning_media_paths: List of paths to conditioning media (images or videos)
+         conditioning_strengths: List of conditioning strengths for each media item
+         conditioning_start_frames: List of frame indices where each item should be applied
+         height: Height of the output frames
+         width: Width of the output frames
+         num_frames: Number of frames in the output video
+         padding: Padding to apply to the frames
+         pipeline: LTXVideoPipeline object used for condition video trimming
+ 
+     Returns:
+         A list of ConditioningItem objects.
+     """
+     conditioning_items = []
+     for path, strength, start_frame in zip(
+         conditioning_media_paths, conditioning_strengths, conditioning_start_frames
+     ):
+         num_input_frames = orig_num_input_frames = get_media_num_frames(path)
+         if hasattr(pipeline, "trim_conditioning_sequence") and callable(
+             getattr(pipeline, "trim_conditioning_sequence")
+         ):
+             num_input_frames = pipeline.trim_conditioning_sequence(
+                 start_frame, orig_num_input_frames, num_frames
+             )
+         if num_input_frames < orig_num_input_frames:
+             logger.warning(
+                 f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
+             )
+ 
+         media_tensor = load_media_file(
+             media_path=path,
+             height=height,
+             width=width,
+             max_frames=num_input_frames,
+             padding=padding,
+             just_crop=True,
+         )
+         conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
+     return conditioning_items
+ 
+ 
+ def get_media_num_frames(media_path: str) -> int:
+     is_video = any(
+         media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
+     )
+     num_frames = 1
+     if is_video:
+         reader = imageio.get_reader(media_path)
+         num_frames = reader.count_frames()
+         reader.close()
+     return num_frames
+ 
+ 
+ def load_media_file(
+     media_path: str,
+     height: int,
+     width: int,
+     max_frames: int,
+     padding: tuple[int, int, int, int],
+     just_crop: bool = False,
+ ) -> torch.Tensor:
+     is_video = any(
+         media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
+     )
+     if is_video:
+         reader = imageio.get_reader(media_path)
+         num_input_frames = min(reader.count_frames(), max_frames)
+ 
+         # Read and preprocess the relevant frames from the video file.
+         frames = []
+         for i in range(num_input_frames):
+             frame = Image.fromarray(reader.get_data(i))
+             frame_tensor = load_image_to_tensor_with_resize_and_crop(
+                 frame, height, width, just_crop=just_crop
+             )
+             frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
+             frames.append(frame_tensor)
+         reader.close()
+ 
+         # Stack frames along the temporal dimension
+         media_tensor = torch.cat(frames, dim=2)
+     else:  # Input image
+         media_tensor = load_image_to_tensor_with_resize_and_crop(
+             media_path, height, width, just_crop=just_crop
+         )
+         media_tensor = torch.nn.functional.pad(media_tensor, padding)
+     return media_tensor
+ 
+ 
+ if __name__ == "__main__":
+     main()
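+ 
+ # Programmatic sketch (hypothetical usage; assumes this module is importable and the
+ # files referenced by the pipeline config are available):
+ #
+ #   from api.ltx.inference import infer
+ #
+ #   infer(
+ #       output_path=None,
+ #       seed=171198,
+ #       pipeline_config="configs/ltxv-13b-0.9.7-dev.yaml",
+ #       image_cond_noise_scale=0.15,
+ #       height=704,
+ #       width=1216,
+ #       num_frames=121,
+ #       frame_rate=30,
+ #       prompt="a sailboat drifting across a calm bay at sunset",
+ #       negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
+ #       offload_to_cpu=False,
+ #   )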