#!/usr/bin/env python3 # ruff: noqa: T201 """ CLI script for running LTX video/audio generation inference. Usage: # Text-to-Video + Audio (default behavior) python scripts/inference.py --checkpoint path/to/model.safetensors \ --text-encoder-path path/to/gemma \ --prompt "A cat playing with a ball" --output output.mp4 # Video only (skip audio) python scripts/inference.py --checkpoint path/to/model.safetensors \ --text-encoder-path path/to/gemma \ --prompt "A cat playing with a ball" --skip-audio --output output.mp4 # Image-to-Video python scripts/inference.py --checkpoint path/to/model.safetensors \ --text-encoder-path path/to/gemma \ --prompt "A cat walking" --condition-image first_frame.png --output output.mp4 # Video-to-Video (IC-LoRA style) python scripts/inference.py --checkpoint path/to/model.safetensors \ --text-encoder-path path/to/gemma \ --prompt "A cat turning into a dog" --reference-video input.mp4 --output output.mp4 # With LoRA weights python scripts/inference.py --checkpoint path/to/model.safetensors \ --text-encoder-path path/to/gemma \ --lora-path path/to/lora.safetensors \ --prompt "A cat in my custom style" --output output.mp4 """ import argparse import re from pathlib import Path import torch import torchaudio from peft import LoraConfig, get_peft_model, set_peft_model_state_dict from safetensors.torch import load_file from torchvision import transforms from ltx_trainer.model_loader import load_model from ltx_trainer.progress import StandaloneSamplingProgress from ltx_trainer.utils import open_image_as_srgb from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler from ltx_trainer.video_utils import read_video, save_video def load_image(image_path: str) -> torch.Tensor: """Load an image and convert to tensor [C, H, W] in [0, 1].""" image = open_image_as_srgb(image_path) transform = transforms.ToTensor() return transform(image) def extract_lora_target_modules(state_dict: dict[str, torch.Tensor]) -> list[str]: """Extract target module names from LoRA checkpoint keys. LoRA keys follow the pattern (after removing "diffusion_model." prefix): - transformer_blocks.0.attn1.to_k.lora_A.weight - transformer_blocks.0.ff.net.0.proj.lora_B.weight This extracts the full module path like "transformer_blocks.0.attn1.to_k". Using full paths is more robust than partial patterns. """ target_modules = set() # Pattern to extract everything before .lora_A or .lora_B pattern = re.compile(r"(.+)\.lora_[AB]\.") for key in state_dict: match = pattern.match(key) if match: module_path = match.group(1) target_modules.add(module_path) return sorted(target_modules) def load_lora_weights(transformer: torch.nn.Module, lora_path: str | Path) -> torch.nn.Module: """Load LoRA weights into the transformer model. The LoRA rank and target modules are automatically detected from the checkpoint. Alpha is set equal to rank (standard practice for inference). Args: transformer: The base transformer model lora_path: Path to the LoRA weights (.safetensors) Returns: The transformer model with LoRA weights applied """ print(f"Loading LoRA weights from {lora_path}...") # Load the LoRA state dict state_dict = load_file(str(lora_path)) # Remove "diffusion_model." prefix (ComfyUI-compatible format) state_dict = {k.replace("diffusion_model.", "", 1): v for k, v in state_dict.items()} # Extract target modules from the checkpoint target_modules = extract_lora_target_modules(state_dict) if not target_modules: raise ValueError(f"Could not extract target modules from LoRA checkpoint: {lora_path}") print(f" Detected {len(target_modules)} target modules") # Auto-detect rank from the first lora_A weight shape lora_rank = None for key, value in state_dict.items(): if "lora_A" in key and value.ndim == 2: lora_rank = value.shape[0] break if lora_rank is None: raise ValueError("Could not auto-detect LoRA rank from weights") print(f" LoRA rank: {lora_rank}") # Create LoRA config and wrap the model # Alpha = rank is standard for inference (maintains the trained scale) lora_config = LoraConfig( r=lora_rank, lora_alpha=lora_rank, target_modules=target_modules, lora_dropout=0.0, init_lora_weights=True, ) # Wrap the transformer with PEFT to add LoRA layers transformer = get_peft_model(transformer, lora_config) # Load the LoRA weights base_model = transformer.get_base_model() set_peft_model_state_dict(base_model, state_dict) print("✓ LoRA weights loaded successfully") return transformer def main() -> None: # noqa: PLR0912, PLR0915 parser = argparse.ArgumentParser( description="LTX Video/Audio Generation", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) # Model arguments parser.add_argument( "--checkpoint", type=str, required=True, help="Path to model checkpoint (.safetensors)", ) parser.add_argument( "--text-encoder-path", type=str, required=True, help="Path to Gemma text encoder directory", ) # LoRA arguments parser.add_argument( "--lora-path", type=str, default=None, help="Path to LoRA weights (.safetensors)", ) # Generation arguments parser.add_argument( "--prompt", type=str, required=True, help="Text prompt for generation", ) parser.add_argument( "--negative-prompt", type=str, default="", help="Negative prompt", ) parser.add_argument( "--height", type=int, default=544, help="Video height (must be divisible by 32)", ) parser.add_argument( "--width", type=int, default=960, help="Video width (must be divisible by 32)", ) parser.add_argument( "--num-frames", type=int, default=97, help="Number of video frames (must be k*8 + 1)", ) parser.add_argument( "--frame-rate", type=float, default=25.0, help="Video frame rate", ) parser.add_argument( "--num-inference-steps", type=int, default=30, help="Number of denoising steps", ) parser.add_argument( "--guidance-scale", type=float, default=3.0, help="Classifier-free guidance scale (CFG)", ) parser.add_argument( "--stg-scale", type=float, default=1.0, help="STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. Default: 1.0", ) parser.add_argument( "--stg-blocks", type=int, nargs="*", default=[29], help="Which transformer blocks to perturb for STG. Default: 29 (single block).", ) parser.add_argument( "--stg-mode", type=str, default="stg_av", choices=["stg_av", "stg_v"], help="STG mode: 'stg_av' perturbs both audio and video, 'stg_v' perturbs video only", ) parser.add_argument( "--seed", type=int, default=42, help="Random seed for reproducibility", ) # Conditioning arguments parser.add_argument( "--condition-image", type=str, default=None, help="Path to conditioning image for image-to-video generation", ) parser.add_argument( "--reference-video", type=str, default=None, help="Path to reference video for video-to-video generation (IC-LoRA style)", ) parser.add_argument( "--include-reference-in-output", action="store_true", help="Include reference video side-by-side with generated output (only for V2V)", ) # Audio arguments parser.add_argument( "--skip-audio", action="store_true", help="Skip audio generation (by default, audio is generated alongside video)", ) # Output arguments parser.add_argument( "--output", type=str, required=True, help="Output video path (.mp4)", ) parser.add_argument( "--audio-output", type=str, default=None, help="Output audio path (.wav, optional - if not provided, audio will be embedded in video)", ) # Device arguments parser.add_argument( "--device", type=str, default="cuda", help="Device to run on (cuda/cpu)", ) args = parser.parse_args() # Validate conditioning arguments if args.include_reference_in_output and args.reference_video is None: parser.error("--include-reference-in-output requires --reference-video") # Validate arguments generate_audio = not args.skip_audio print("=" * 80) print("LTX Video/Audio Generation") print("=" * 80) # Determine if we need VAE encoder (for image or video conditioning) need_vae_encoder = args.condition_image is not None or args.reference_video is not None components = load_model( checkpoint_path=args.checkpoint, device="cpu", # Load to CPU first, sampler will move to device as needed dtype=torch.bfloat16, with_video_vae_encoder=need_vae_encoder, with_video_vae_decoder=True, with_audio_vae_decoder=generate_audio, with_vocoder=generate_audio, with_text_encoder=True, text_encoder_path=args.text_encoder_path, ) # Apply LoRA weights if provided transformer = components.transformer if args.lora_path is not None: transformer = load_lora_weights(transformer, args.lora_path) # Load conditioning image if provided condition_image = None if args.condition_image: print(f"Loading conditioning image from {args.condition_image}...") condition_image = load_image(args.condition_image) # Load reference video if provided reference_video = None if args.reference_video: print(f"Loading reference video from {args.reference_video}...") reference_video, ref_fps = read_video(args.reference_video, max_frames=args.num_frames) print(f" Loaded {reference_video.shape[0]} frames @ {ref_fps:.1f} fps") # Determine generation mode if args.reference_video is not None and args.condition_image is not None: mode = "Video-to-Video + Image Conditioning (V2V+I2V)" elif args.reference_video is not None: mode = "Video-to-Video (V2V)" elif args.condition_image is not None: mode = "Image-to-Video (I2V)" else: mode = "Text-to-Video (T2V)" print("\n" + "=" * 80) print("Generation Parameters") print("=" * 80) print(f"Mode: {mode}") print(f"Prompt: {args.prompt}") if args.negative_prompt: print(f"Negative prompt: {args.negative_prompt}") print(f"Resolution: {args.width}x{args.height}") print(f"Frames: {args.num_frames} @ {args.frame_rate} fps") print(f"Inference steps: {args.num_inference_steps}") print(f"CFG scale: {args.guidance_scale}") if args.stg_scale > 0: blocks_str = args.stg_blocks if args.stg_blocks else "all" print(f"STG scale: {args.stg_scale} (mode: {args.stg_mode}, blocks: {blocks_str})") else: print("STG: disabled") print(f"Seed: {args.seed}") if args.lora_path: print(f"LoRA: {args.lora_path}") if condition_image is not None: print(f"Conditioning: Image ({args.condition_image})") if reference_video is not None: print(f"Reference: Video ({args.reference_video})") if args.include_reference_in_output: print(" → Will include reference side-by-side in output") if generate_audio: video_duration = args.num_frames / args.frame_rate print(f"Audio: Enabled (duration will match video: {video_duration:.2f}s)") print("=" * 80) print(f"\nGenerating {'video + audio' if generate_audio else 'video'}...") # Create generation config gen_config = GenerationConfig( prompt=args.prompt, negative_prompt=args.negative_prompt, height=args.height, width=args.width, num_frames=args.num_frames, frame_rate=args.frame_rate, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, seed=args.seed, condition_image=condition_image, reference_video=reference_video, generate_audio=generate_audio, include_reference_in_output=args.include_reference_in_output, stg_scale=args.stg_scale, stg_blocks=args.stg_blocks, stg_mode=args.stg_mode, ) # Generate with progress bar with StandaloneSamplingProgress(num_steps=args.num_inference_steps) as progress: # Create sampler with progress context sampler = ValidationSampler( transformer=transformer, vae_decoder=components.video_vae_decoder, vae_encoder=components.video_vae_encoder, text_encoder=components.text_encoder, audio_decoder=components.audio_vae_decoder if generate_audio else None, vocoder=components.vocoder if generate_audio else None, sampling_context=progress, ) video, audio = sampler.generate( config=gen_config, device=args.device, ) # Save video output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) # Get audio sample rate from vocoder if audio was generated audio_sample_rate = None if audio is not None and components.vocoder is not None: audio_sample_rate = components.vocoder.output_sample_rate save_video( video_tensor=video, output_path=output_path, fps=args.frame_rate, audio=audio, audio_sample_rate=audio_sample_rate, ) print(f"✓ Video saved to {args.output}") # Save separate audio file if requested if audio is not None and args.audio_output is not None: audio_output_path = Path(args.audio_output) audio_output_path.parent.mkdir(parents=True, exist_ok=True) torchaudio.save( str(audio_output_path), audio.cpu(), sample_rate=audio_sample_rate, ) duration = audio.shape[1] / audio_sample_rate print(f"✓ Audio saved: {duration:.2f}s at {audio_sample_rate}Hz") print("\n" + "=" * 80) print("Generation complete!") print("=" * 80) if __name__ == "__main__": main()