Spaces:
Paused
Paused
| import argparse | |
| import os | |
| import random | |
| from datetime import datetime | |
| from pathlib import Path | |
| from diffusers.utils import logging | |
| import imageio | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from transformers import T5EncoderModel, T5Tokenizer | |
| from ltx_video.models.autoencoders.causal_video_autoencoder import ( | |
| CausalVideoAutoencoder, | |
| ) | |
| from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier | |
| from ltx_video.models.transformers.transformer3d import Transformer3DModel | |
| from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline | |
| from ltx_video.schedulers.rf import RectifiedFlowScheduler | |
| from ltx_video.utils.conditioning_method import ConditioningMethod | |
| from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy | |
# Suggested upper bounds for the requested output size. main() only logs a
# warning when one of them is exceeded; the limits are not enforced.
MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257
def get_total_gpu_memory():
    """Return the total memory of CUDA device 0 in GiB.

    Returns:
        The device's ``total_memory`` divided by ``1024**3`` as a float, or
        ``None`` when ``torch.cuda.is_available()`` is false.
    """
    if not torch.cuda.is_available():
        return None
    device_props = torch.cuda.get_device_properties(0)
    bytes_per_gib = 1024**3
    return device_props.total_memory / bytes_per_gib
def load_image_to_tensor_with_resize_and_crop(
    image_path, target_height=512, target_width=768
):
    """Load an image, center-crop it to the target aspect ratio and resize it.

    Args:
        image_path: Path (or anything ``PIL.Image.open`` accepts) of the image.
        target_height: Height of the returned frame in pixels.
        target_width: Width of the returned frame in pixels.

    Returns:
        A float tensor of shape ``(1, 3, 1, target_height, target_width)``
        (batch, channels, frames, height, width) with values scaled from
        ``[0, 255]`` to ``[-1, 1]``.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # releases it deterministically. convert() returns an independent, fully
    # loaded copy, so it stays valid after the source is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height

    if aspect_ratio_frame > aspect_ratio_target:
        # Source is wider than the target: keep full height, trim the sides.
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        # Source is taller (or equal): keep full width, trim top and bottom.
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height))

    # HWC uint8 array -> CHW float tensor, then map [0, 255] to [-1, 1].
    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    frame_tensor = (frame_tensor / 127.5) - 1.0

    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute per-side padding that grows a source size to a target size.

    The extra rows/columns are split between the two opposite sides; when the
    amount is odd, the bottom/right side receives the extra pixel.

    Returns:
        ``(pad_left, pad_right, pad_top, pad_bottom)``.
    """
    extra_rows = target_height - source_height
    extra_cols = target_width - source_width

    top, left = extra_rows // 2, extra_cols // 2
    # The trailing side absorbs the remainder when the total is odd.
    bottom, right = extra_rows - top, extra_cols - left

    # Padding format is (left, right, top, bottom)
    return (left, right, top, bottom)
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Turn a prompt into a short, filename-friendly slug.

    Non-letter characters are dropped, the text is lower-cased, and whole
    words are joined with ``-`` for as long as the joined result (separators
    included) stays within ``max_len`` characters.

    Args:
        text: The prompt. ``None`` or an empty string yields ``""``.
        max_len: Maximum length of the returned slug.

    Returns:
        The slug; ``""`` when no word fits or the prompt has no letters.
    """
    # Keep letters and whitespace only; tolerate a missing prompt.
    clean_text = "".join(
        char.lower() for char in (text or "") if char.isalpha() or char.isspace()
    )

    result = []
    current_length = 0
    for word in clean_text.split():
        # Every word after the first costs one extra character for the "-"
        # separator, so the joined string never exceeds max_len.
        separator_length = 1 if result else 0
        new_length = current_length + separator_length + len(word)
        if new_length > max_len:
            break
        result.append(word)
        current_length = new_length

    return "-".join(result)
# Generate output video name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    """Build an output path inside ``dir`` that does not exist yet.

    The name is ``{base}_{prompt slug}_{seed}_{r0}x{r1}x{r2}_{index}`` followed
    by the optional ``endswith`` marker and ``ext``. Indices ``0`` to
    ``index_range - 1`` are tried in order and the first free path is returned.

    Raises:
        FileExistsError: If every candidate index is already taken.
    """
    prompt_slug = convert_prompt_to_filename(prompt, max_len=30)
    dims = f"{resolution[0]}x{resolution[1]}x{resolution[2]}"
    stem = f"{base}_{prompt_slug}_{seed}_{dims}"
    tail = f"{endswith if endswith else ''}{ext}"

    candidates = (dir / f"{stem}_{index}{tail}" for index in range(index_range))
    for candidate in candidates:
        if not os.path.exists(candidate):
            return candidate

    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )
def seed_everething(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    ``torch.cuda.manual_seed`` is additionally called when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
def _build_arg_parser():
    """Create the command-line parser with every option main() reads."""
    parser = argparse.ArgumentParser(
        description="Load models from separate directories and run the pipeline."
    )

    # Directories
    parser.add_argument(
        "--ckpt_path",
        type=str,
        required=True,
        help="Path to a safetensors file that contains all model parts.",
    )
    # NOTE: --input_video_path is accepted but not read anywhere in main().
    parser.add_argument(
        "--input_video_path",
        type=str,
        help="Path to the input video file (first frame used)",
    )
    parser.add_argument(
        "--input_image_path", type=str, help="Path to the input image file"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to the folder to save output video, if None will save in outputs/ directory.",
    )
    parser.add_argument("--seed", type=int, default=171198)

    # Pipeline parameters
    parser.add_argument(
        "--num_inference_steps", type=int, default=40, help="Number of inference steps"
    )
    parser.add_argument(
        "--num_images_per_prompt",
        type=int,
        default=1,
        help="Number of images per prompt",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3,
        help="Guidance scale for the pipeline",
    )
    parser.add_argument(
        "--stg_scale",
        type=float,
        default=1,
        help="Spatiotemporal guidance scale for the pipeline. 0 to disable STG.",
    )
    parser.add_argument(
        "--stg_rescale",
        type=float,
        default=0.7,
        help="Spatiotemporal guidance rescaling scale for the pipeline. 1 to disable rescale.",
    )
    parser.add_argument(
        "--stg_mode",
        type=str,
        default="stg_a",
        help="Spatiotemporal guidance mode for the pipeline. Can be either stg_a or stg_r.",
    )
    parser.add_argument(
        "--stg_skip_layers",
        type=str,
        default="19",
        help="Attention layers to skip for spatiotemporal guidance. Comma separated list of integers.",
    )
    parser.add_argument(
        "--image_cond_noise_scale",
        type=float,
        default=0.15,
        help="Amount of noise to add to the conditioned image",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Height of the output video frames. Optional if an input image provided.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=704,
        help="Width of the output video frames. If None will infer from input image.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=121,
        help="Number of frames to generate in the output video",
    )
    parser.add_argument(
        "--frame_rate", type=int, default=25, help="Frame rate for the output video"
    )
    parser.add_argument(
        "--precision",
        choices=["bfloat16", "mixed_precision"],
        default="bfloat16",
        help="Sets the precision for the transformer and tokenizer. Default is bfloat16. If 'mixed_precision' is enabled, it moves to mixed-precision.",
    )

    # VAE noise augmentation
    parser.add_argument(
        "--decode_timestep",
        type=float,
        default=0.05,
        help="Timestep for decoding noise",
    )
    parser.add_argument(
        "--decode_noise_scale",
        type=float,
        default=0.025,
        help="Noise level for decoding noise",
    )

    # Prompts
    parser.add_argument(
        "--prompt",
        type=str,
        help="Text prompt to guide generation",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        help="Negative prompt for undesired features",
    )
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )
    return parser


def main():
    """Command-line entry point: generate a video (or image) and save it.

    Steps, in order: parse arguments, seed the RNGs, optionally load and pad a
    conditioning image, load the model parts from ``--ckpt_path``, run
    ``LTXVideoPipeline``, crop the result back to the requested size and write
    each sample to the output directory (``.png`` for a single frame, ``.mp4``
    otherwise, plus the conditioning image when one was given).
    """
    parser = _build_arg_parser()
    logger = logging.get_logger(__name__)

    args = parser.parse_args()
    logger.warning(f"Running generation with arguments: {args}")

    seed_everething(args.seed)

    # get_total_gpu_memory() returns None without CUDA; comparing None < 30
    # would raise TypeError, so offloading is simply disabled in that case.
    total_gpu_memory = get_total_gpu_memory()
    offload_to_cpu = (
        args.offload_to_cpu
        and total_gpu_memory is not None
        and total_gpu_memory < 30
    )

    output_dir = (
        Path(args.output_path)
        if args.output_path
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load image
    if args.input_image_path:
        media_items_prepad = load_image_to_tensor_with_resize_and_crop(
            args.input_image_path, args.height, args.width
        )
    else:
        media_items_prepad = None

    height = args.height if args.height else media_items_prepad.shape[-2]
    width = args.width if args.width else media_items_prepad.shape[-1]
    num_frames = args.num_frames

    if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES:
        logger.warning(
            f"Input resolution or number of frames {height}x{width}x{num_frames} is too big, it is suggested to use the resolution below {MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES}."
        )

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(height, width, height_padded, width_padded)

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    if media_items_prepad is not None:
        media_items = F.pad(
            media_items_prepad, padding, mode="constant", value=-1
        )  # -1 is the value for padding since the image is normalized to -1, 1
    else:
        media_items = None

    ckpt_path = Path(args.ckpt_path)
    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
    transformer = Transformer3DModel.from_pretrained(ckpt_path)
    scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

    text_encoder = T5EncoderModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
    )
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
    )

    if torch.cuda.is_available():
        transformer = transformer.cuda()
        vae = vae.cuda()
        text_encoder = text_encoder.cuda()

    vae = vae.to(torch.bfloat16)
    if args.precision == "bfloat16" and transformer.dtype != torch.bfloat16:
        transformer = transformer.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)

    # Set spatiotemporal guidance
    skip_block_list = [int(x.strip()) for x in args.stg_skip_layers.split(",")]
    stg_mode = args.stg_mode.lower()
    if stg_mode not in ("stg_a", "stg_r"):
        # Any value other than "stg_a" selects the residual strategy; make
        # that visible instead of silently accepting a typo.
        logger.warning(
            f"Unrecognized --stg_mode {args.stg_mode!r}; using SkipLayerStrategy.Residual."
        )
    skip_layer_strategy = (
        SkipLayerStrategy.Attention
        if stg_mode == "stg_a"
        else SkipLayerStrategy.Residual
    )

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": transformer,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    if torch.cuda.is_available():
        pipeline = pipeline.to("cuda")

    # Prepare input for the pipeline
    sample = {
        "prompt": args.prompt,
        "prompt_attention_mask": None,
        "negative_prompt": args.negative_prompt,
        "negative_prompt_attention_mask": None,
        "media_items": media_items,
    }

    generator = torch.Generator(
        device="cuda" if torch.cuda.is_available() else "cpu"
    ).manual_seed(args.seed)

    images = pipeline(
        num_inference_steps=args.num_inference_steps,
        num_images_per_prompt=args.num_images_per_prompt,
        guidance_scale=args.guidance_scale,
        skip_layer_strategy=skip_layer_strategy,
        skip_block_list=skip_block_list,
        stg_scale=args.stg_scale,
        do_rescaling=args.stg_rescale != 1,
        rescaling_scale=args.stg_rescale,
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=args.frame_rate,
        **sample,
        is_video=True,
        vae_per_channel_normalize=True,
        conditioning_method=(
            ConditioningMethod.FIRST_FRAME
            if media_items is not None
            else ConditioningMethod.UNCONDITIONAL
        ),
        image_cond_noise_scale=args.image_cond_noise_scale,
        decode_timestep=args.decode_timestep,
        decode_noise_scale=args.decode_noise_scale,
        mixed_precision=(args.precision == "mixed_precision"),
        offload_to_cpu=offload_to_cpu,
    ).images

    # Crop the padded images to the desired resolution and number of frames
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    # Convert the trailing pads into negative slice ends; a pad of 0 would
    # give an empty slice, so fall back to the full extent in that case.
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]

    # --prompt is optional, so args.prompt may be None; use an empty string
    # for file naming so saving cannot fail after generation has finished.
    filename_prompt = args.prompt or ""

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Unnormalizing images to [0, 255] range. Clipping first guards against
        # uint8 wrap-around should any value fall outside [0, 1].
        video_np = (np.clip(video_np, 0.0, 1.0) * 255).astype(np.uint8)
        fps = args.frame_rate
        height, width = video_np.shape[1:3]

        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=filename_prompt,
                seed=args.seed,
                resolution=(height, width, num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            if args.input_image_path:
                base_filename = f"img_to_vid_{i}"
            else:
                base_filename = f"text_to_vid_{i}"
            output_filename = get_unique_filename(
                base_filename,
                ".mp4",
                prompt=filename_prompt,
                seed=args.seed,
                resolution=(height, width, num_frames),
                dir=output_dir,
            )

            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

            # Write condition image
            if args.input_image_path:
                # Unpadded conditioning frame: C, H, W -> H, W, C and
                # [-1, 1] -> [0, 255].
                reference_image = (
                    (
                        media_items_prepad[0, :, 0].permute(1, 2, 0).cpu().data.numpy()
                        + 1.0
                    )
                    / 2.0
                    * 255
                )
                imageio.imwrite(
                    get_unique_filename(
                        base_filename,
                        ".png",
                        prompt=filename_prompt,
                        seed=args.seed,
                        resolution=(height, width, num_frames),
                        dir=output_dir,
                        endswith="_condition",
                    ),
                    reference_image.astype(np.uint8),
                )

        logger.warning(f"Output saved to {output_dir}")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()