import gradio as gr
from gradio_toggle import Toggle
import argparse
import json
import os
import random
from datetime import datetime
from pathlib import Path
from diffusers.utils import logging
import imageio
import numpy as np
import safetensors.torch
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer
import tempfile

from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.conditioning_method import ConditioningMethod
from torchao.quantization import quantize_, int8_weight_only

MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257


def load_vae(vae_dir, int8=False):
    vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
    vae_config_path = vae_dir / "config.json"
    with open(vae_config_path, "r") as f:
        vae_config = json.load(f)
    vae = CausalVideoAutoencoder.from_config(vae_config)
    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
    vae.load_state_dict(vae_state_dict)
    # Ensure everything runs on the CPU
    vae = vae.to('cpu')
    if int8:
        print("vae - quantization = true")
        quantize_(vae, int8_weight_only())
    return vae


def load_unet(unet_dir, int8=False):
    unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
    unet_config_path = unet_dir / "config.json"
    transformer_config = Transformer3DModel.load_config(unet_config_path)
    transformer = Transformer3DModel.from_config(transformer_config)
    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
    transformer.load_state_dict(unet_state_dict, strict=True)
    # Ensure everything runs on the CPU
    transformer = transformer.to('cpu')
    if int8:
        print("unet - quantization = true")
        quantize_(transformer, int8_weight_only())
    return transformer


def load_scheduler(scheduler_dir):
    scheduler_config_path = scheduler_dir / "scheduler_config.json"
    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
    return RectifiedFlowScheduler.from_config(scheduler_config)


def load_image_to_tensor_with_resize_and_crop(
    image_path, target_height=512, target_width=768
):
    image = Image.open(image_path).convert("RGB")
    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        # Source is wider than the target: center-crop the width
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        # Source is taller than the target: center-crop the height
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height))
    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)


def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    # Calculate total padding needed
    pad_height = target_height - source_height
    pad_width = target_width - source_width

    # Calculate padding for each side
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top  # Handles odd padding
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left  # Handles odd padding

    # Return the padding amounts in F.pad's (left, right, top, bottom) format
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    return padding


def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    # Remove non-letters and convert to lowercase
    clean_text = "".join(
        char.lower() for char in text if char.isalpha() or char.isspace()
    )

    # Split into words
    words = clean_text.split()

    # Build the result, keeping the total length within max_len
    result = []
    current_length = 0
    for word in words:
        new_length = current_length + len(word)
        if new_length <= max_len:
            result.append(word)
            current_length += len(word)
        else:
            break

    return "-".join(result)


# Generate output video name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    base_filename = (
        f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}"
        f"_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    )
    for i in range(index_range):
        filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
        if not os.path.exists(filename):
            return filename
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )


def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def main(
    img2vid_image="",
    prompt="",
    txt2vid_analytics_toggle=False,
    negative_prompt="",
    frame_rate=25,
    seed=0,
    num_inference_steps=30,
    guidance_scale=3,
    height=512,
    width=768,
    num_frames=121,
    progress=gr.Progress(),
):
    logger = logging.get_logger(__name__)

    args = {
        "ckpt_dir": "Lightricks/LTX-Video",
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "frame_rate": frame_rate,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "seed": seed,
        "output_path": os.path.join(tempfile.gettempdir(), "gradio"),
        "num_images_per_prompt": 1,
        "input_image_path": img2vid_image,
        "input_video_path": "",
        "bfloat16": True,
        "disable_load_needed_only": False,
    }

    logger.warning(f"Running generation with arguments: {args}")

    seed_everything(args['seed'])

    output_dir = (
        Path(args['output_path'])
        if args['output_path']
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load the conditioning image if one was provided (img2vid mode)
    if args['input_image_path']:
        media_items_prepad = load_image_to_tensor_with_resize_and_crop(
            args['input_image_path'], args['height'], args['width']
        )
    else:
        media_items_prepad = None

    height = args['height'] if args['height'] else media_items_prepad.shape[-2]
    width = args['width'] if args['width'] else media_items_prepad.shape[-1]
    num_frames = args['num_frames']

    if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES:
        logger.warning(
            f"Input resolution or number of frames {height}x{width}x{num_frames} is too big; "
            f"it is suggested to use values at or below {MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES}."
        )

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(height, width, height_padded, width_padded)

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    if media_items_prepad is not None:
        media_items = F.pad(
            media_items_prepad, padding, mode="constant", value=-1
        )  # -1 is the padding value since the image is normalized to [-1, 1]
    else:
        media_items = None

    # Load models
    vae = load_vae(Path(args['ckpt_dir']) / "vae", txt2vid_analytics_toggle)
    unet = load_unet(Path(args['ckpt_dir']) / "unet", txt2vid_analytics_toggle)
    scheduler = load_scheduler(Path(args['ckpt_dir']) / "scheduler")
    patchifier = SymmetricPatchifier(patch_size=1)
    text_encoder = T5EncoderModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
    ).to('cpu')  # Force to CPU
    tokenizer = T5Tokenizer.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
    )

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": unet,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    pipeline = pipeline.to('cpu')  # Ensure pipeline runs on CPU

    # Prepare input for the pipeline
    sample = {
        "prompt": args['prompt'],
        "prompt_attention_mask": None,
        "negative_prompt": args['negative_prompt'],
        "negative_prompt_attention_mask": None,
        "media_items": media_items,
    }

    generator = torch.Generator(device="cpu").manual_seed(args['seed'])  # Force CPU

    images = pipeline(
        num_inference_steps=args['num_inference_steps'],
        num_images_per_prompt=args['num_images_per_prompt'],
        guidance_scale=args['guidance_scale'],
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=args['frame_rate'],
        **sample,
        is_video=True,
        vae_per_channel_normalize=True,
        conditioning_method=(
            ConditioningMethod.FIRST_FRAME
            if media_items is not None
            else ConditioningMethod.UNCONDITIONAL
        ),
        mixed_precision=not args['bfloat16'],
        load_needed_only=not args['disable_load_needed_only'],
    ).images

    # Further processing and saving logic can go here...
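    # A minimal sketch of that saving step, not taken from the original: it
    # assumes `images` is a (batch, channels, frames, height, width) tensor in
    # [0, 1] (the usual shape for output_type="pt"), crops away the spatial and
    # temporal padding, and writes each sample as an MP4 via imageio (which
    # requires the imageio-ffmpeg backend for .mp4 output).
    pad_left, pad_right, pad_top, pad_bottom = padding
    for i in range(images.shape[0]):
        # (C, F, H, W) -> (F, H, W, C), scaled to uint8 for the video writer
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        video_np = (video_np * 255).astype(np.uint8)
        # Crop back to the requested (unpadded) frame count and resolution;
        # `or None` keeps the full extent when a padding amount is zero
        video_np = video_np[
            :num_frames,
            pad_top : -pad_bottom or None,
            pad_left : -pad_right or None,
        ]
        output_filename = get_unique_filename(
            "video",
            ".mp4",
            prompt=args['prompt'],
            seed=args['seed'],
            resolution=(height, width, num_frames),
            dir=output_dir,
        )
        with imageio.get_writer(output_filename, fps=args['frame_rate']) as writer:
            for frame in video_np:
                writer.append_data(frame)
    # Return the last written path so a Gradio Video output can display it
    return str(output_filename)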