linoy
inital commit
ebfc6b3
#!/usr/bin/env python3
# ruff: noqa: T201
"""
CLI script for running LTX video/audio generation inference.
Usage:
# Text-to-Video + Audio (default behavior)
python scripts/inference.py --checkpoint path/to/model.safetensors \
--text-encoder-path path/to/gemma \
--prompt "A cat playing with a ball" --output output.mp4
# Video only (skip audio)
python scripts/inference.py --checkpoint path/to/model.safetensors \
--text-encoder-path path/to/gemma \
--prompt "A cat playing with a ball" --skip-audio --output output.mp4
# Image-to-Video
python scripts/inference.py --checkpoint path/to/model.safetensors \
--text-encoder-path path/to/gemma \
--prompt "A cat walking" --condition-image first_frame.png --output output.mp4
# Video-to-Video (IC-LoRA style)
python scripts/inference.py --checkpoint path/to/model.safetensors \
--text-encoder-path path/to/gemma \
--prompt "A cat turning into a dog" --reference-video input.mp4 --output output.mp4
# With LoRA weights
python scripts/inference.py --checkpoint path/to/model.safetensors \
--text-encoder-path path/to/gemma \
--lora-path path/to/lora.safetensors \
--prompt "A cat in my custom style" --output output.mp4
"""
import argparse
import re
from pathlib import Path
import torch
import torchaudio
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
from safetensors.torch import load_file
from torchvision import transforms
from ltx_trainer.model_loader import load_model
from ltx_trainer.progress import StandaloneSamplingProgress
from ltx_trainer.utils import open_image_as_srgb
from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler
from ltx_trainer.video_utils import read_video, save_video
def load_image(image_path: str) -> torch.Tensor:
"""Load an image and convert to tensor [C, H, W] in [0, 1]."""
image = open_image_as_srgb(image_path)
transform = transforms.ToTensor()
return transform(image)
def extract_lora_target_modules(state_dict: dict[str, torch.Tensor]) -> list[str]:
"""Extract target module names from LoRA checkpoint keys.
LoRA keys follow the pattern (after removing "diffusion_model." prefix):
- transformer_blocks.0.attn1.to_k.lora_A.weight
- transformer_blocks.0.ff.net.0.proj.lora_B.weight
This extracts the full module path like "transformer_blocks.0.attn1.to_k".
Using full paths is more robust than partial patterns.
"""
target_modules = set()
# Pattern to extract everything before .lora_A or .lora_B
pattern = re.compile(r"(.+)\.lora_[AB]\.")
for key in state_dict:
match = pattern.match(key)
if match:
module_path = match.group(1)
target_modules.add(module_path)
return sorted(target_modules)
def load_lora_weights(transformer: torch.nn.Module, lora_path: str | Path) -> torch.nn.Module:
"""Load LoRA weights into the transformer model.
The LoRA rank and target modules are automatically detected from the checkpoint.
Alpha is set equal to rank (standard practice for inference).
Args:
transformer: The base transformer model
lora_path: Path to the LoRA weights (.safetensors)
Returns:
The transformer model with LoRA weights applied
"""
print(f"Loading LoRA weights from {lora_path}...")
# Load the LoRA state dict
state_dict = load_file(str(lora_path))
# Remove "diffusion_model." prefix (ComfyUI-compatible format)
state_dict = {k.replace("diffusion_model.", "", 1): v for k, v in state_dict.items()}
# Extract target modules from the checkpoint
target_modules = extract_lora_target_modules(state_dict)
if not target_modules:
raise ValueError(f"Could not extract target modules from LoRA checkpoint: {lora_path}")
print(f" Detected {len(target_modules)} target modules")
# Auto-detect rank from the first lora_A weight shape
lora_rank = None
for key, value in state_dict.items():
if "lora_A" in key and value.ndim == 2:
lora_rank = value.shape[0]
break
if lora_rank is None:
raise ValueError("Could not auto-detect LoRA rank from weights")
print(f" LoRA rank: {lora_rank}")
# Create LoRA config and wrap the model
# Alpha = rank is standard for inference (maintains the trained scale)
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_rank,
target_modules=target_modules,
lora_dropout=0.0,
init_lora_weights=True,
)
# Wrap the transformer with PEFT to add LoRA layers
transformer = get_peft_model(transformer, lora_config)
# Load the LoRA weights
base_model = transformer.get_base_model()
set_peft_model_state_dict(base_model, state_dict)
print("✓ LoRA weights loaded successfully")
return transformer
def main() -> None: # noqa: PLR0912, PLR0915
parser = argparse.ArgumentParser(
description="LTX Video/Audio Generation",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Model arguments
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to model checkpoint (.safetensors)",
)
parser.add_argument(
"--text-encoder-path",
type=str,
required=True,
help="Path to Gemma text encoder directory",
)
# LoRA arguments
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to LoRA weights (.safetensors)",
)
# Generation arguments
parser.add_argument(
"--prompt",
type=str,
required=True,
help="Text prompt for generation",
)
parser.add_argument(
"--negative-prompt",
type=str,
default="",
help="Negative prompt",
)
parser.add_argument(
"--height",
type=int,
default=544,
help="Video height (must be divisible by 32)",
)
parser.add_argument(
"--width",
type=int,
default=960,
help="Video width (must be divisible by 32)",
)
parser.add_argument(
"--num-frames",
type=int,
default=97,
help="Number of video frames (must be k*8 + 1)",
)
parser.add_argument(
"--frame-rate",
type=float,
default=25.0,
help="Video frame rate",
)
parser.add_argument(
"--num-inference-steps",
type=int,
default=30,
help="Number of denoising steps",
)
parser.add_argument(
"--guidance-scale",
type=float,
default=3.0,
help="Classifier-free guidance scale (CFG)",
)
parser.add_argument(
"--stg-scale",
type=float,
default=1.0,
help="STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. Default: 1.0",
)
parser.add_argument(
"--stg-blocks",
type=int,
nargs="*",
default=[29],
help="Which transformer blocks to perturb for STG. Default: 29 (single block).",
)
parser.add_argument(
"--stg-mode",
type=str,
default="stg_av",
choices=["stg_av", "stg_v"],
help="STG mode: 'stg_av' perturbs both audio and video, 'stg_v' perturbs video only",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Random seed for reproducibility",
)
# Conditioning arguments
parser.add_argument(
"--condition-image",
type=str,
default=None,
help="Path to conditioning image for image-to-video generation",
)
parser.add_argument(
"--reference-video",
type=str,
default=None,
help="Path to reference video for video-to-video generation (IC-LoRA style)",
)
parser.add_argument(
"--include-reference-in-output",
action="store_true",
help="Include reference video side-by-side with generated output (only for V2V)",
)
# Audio arguments
parser.add_argument(
"--skip-audio",
action="store_true",
help="Skip audio generation (by default, audio is generated alongside video)",
)
# Output arguments
parser.add_argument(
"--output",
type=str,
required=True,
help="Output video path (.mp4)",
)
parser.add_argument(
"--audio-output",
type=str,
default=None,
help="Output audio path (.wav, optional - if not provided, audio will be embedded in video)",
)
# Device arguments
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to run on (cuda/cpu)",
)
args = parser.parse_args()
# Validate conditioning arguments
if args.include_reference_in_output and args.reference_video is None:
parser.error("--include-reference-in-output requires --reference-video")
# Validate arguments
generate_audio = not args.skip_audio
print("=" * 80)
print("LTX Video/Audio Generation")
print("=" * 80)
# Determine if we need VAE encoder (for image or video conditioning)
need_vae_encoder = args.condition_image is not None or args.reference_video is not None
components = load_model(
checkpoint_path=args.checkpoint,
device="cpu", # Load to CPU first, sampler will move to device as needed
dtype=torch.bfloat16,
with_video_vae_encoder=need_vae_encoder,
with_video_vae_decoder=True,
with_audio_vae_decoder=generate_audio,
with_vocoder=generate_audio,
with_text_encoder=True,
text_encoder_path=args.text_encoder_path,
)
# Apply LoRA weights if provided
transformer = components.transformer
if args.lora_path is not None:
transformer = load_lora_weights(transformer, args.lora_path)
# Load conditioning image if provided
condition_image = None
if args.condition_image:
print(f"Loading conditioning image from {args.condition_image}...")
condition_image = load_image(args.condition_image)
# Load reference video if provided
reference_video = None
if args.reference_video:
print(f"Loading reference video from {args.reference_video}...")
reference_video, ref_fps = read_video(args.reference_video, max_frames=args.num_frames)
print(f" Loaded {reference_video.shape[0]} frames @ {ref_fps:.1f} fps")
# Determine generation mode
if args.reference_video is not None and args.condition_image is not None:
mode = "Video-to-Video + Image Conditioning (V2V+I2V)"
elif args.reference_video is not None:
mode = "Video-to-Video (V2V)"
elif args.condition_image is not None:
mode = "Image-to-Video (I2V)"
else:
mode = "Text-to-Video (T2V)"
print("\n" + "=" * 80)
print("Generation Parameters")
print("=" * 80)
print(f"Mode: {mode}")
print(f"Prompt: {args.prompt}")
if args.negative_prompt:
print(f"Negative prompt: {args.negative_prompt}")
print(f"Resolution: {args.width}x{args.height}")
print(f"Frames: {args.num_frames} @ {args.frame_rate} fps")
print(f"Inference steps: {args.num_inference_steps}")
print(f"CFG scale: {args.guidance_scale}")
if args.stg_scale > 0:
blocks_str = args.stg_blocks if args.stg_blocks else "all"
print(f"STG scale: {args.stg_scale} (mode: {args.stg_mode}, blocks: {blocks_str})")
else:
print("STG: disabled")
print(f"Seed: {args.seed}")
if args.lora_path:
print(f"LoRA: {args.lora_path}")
if condition_image is not None:
print(f"Conditioning: Image ({args.condition_image})")
if reference_video is not None:
print(f"Reference: Video ({args.reference_video})")
if args.include_reference_in_output:
print(" → Will include reference side-by-side in output")
if generate_audio:
video_duration = args.num_frames / args.frame_rate
print(f"Audio: Enabled (duration will match video: {video_duration:.2f}s)")
print("=" * 80)
print(f"\nGenerating {'video + audio' if generate_audio else 'video'}...")
# Create generation config
gen_config = GenerationConfig(
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
frame_rate=args.frame_rate,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
seed=args.seed,
condition_image=condition_image,
reference_video=reference_video,
generate_audio=generate_audio,
include_reference_in_output=args.include_reference_in_output,
stg_scale=args.stg_scale,
stg_blocks=args.stg_blocks,
stg_mode=args.stg_mode,
)
# Generate with progress bar
with StandaloneSamplingProgress(num_steps=args.num_inference_steps) as progress:
# Create sampler with progress context
sampler = ValidationSampler(
transformer=transformer,
vae_decoder=components.video_vae_decoder,
vae_encoder=components.video_vae_encoder,
text_encoder=components.text_encoder,
audio_decoder=components.audio_vae_decoder if generate_audio else None,
vocoder=components.vocoder if generate_audio else None,
sampling_context=progress,
)
video, audio = sampler.generate(
config=gen_config,
device=args.device,
)
# Save video
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Get audio sample rate from vocoder if audio was generated
audio_sample_rate = None
if audio is not None and components.vocoder is not None:
audio_sample_rate = components.vocoder.output_sample_rate
save_video(
video_tensor=video,
output_path=output_path,
fps=args.frame_rate,
audio=audio,
audio_sample_rate=audio_sample_rate,
)
print(f"✓ Video saved to {args.output}")
# Save separate audio file if requested
if audio is not None and args.audio_output is not None:
audio_output_path = Path(args.audio_output)
audio_output_path.parent.mkdir(parents=True, exist_ok=True)
torchaudio.save(
str(audio_output_path),
audio.cpu(),
sample_rate=audio_sample_rate,
)
duration = audio.shape[1] / audio_sample_rate
print(f"✓ Audio saved: {duration:.2f}s at {audio_sample_rate}Hz")
print("\n" + "=" * 80)
print("Generation complete!")
print("=" * 80)
if __name__ == "__main__":
main()