|
|
|
|
|
|
|
|
""" |
|
|
CLI script for running LTX video/audio generation inference. |
|
|
|
|
|
Usage: |
|
|
# Text-to-Video + Audio (default behavior) |
|
|
python scripts/inference.py --checkpoint path/to/model.safetensors \ |
|
|
--text-encoder-path path/to/gemma \ |
|
|
--prompt "A cat playing with a ball" --output output.mp4 |
|
|
|
|
|
# Video only (skip audio) |
|
|
python scripts/inference.py --checkpoint path/to/model.safetensors \ |
|
|
--text-encoder-path path/to/gemma \ |
|
|
--prompt "A cat playing with a ball" --skip-audio --output output.mp4 |
|
|
|
|
|
# Image-to-Video |
|
|
python scripts/inference.py --checkpoint path/to/model.safetensors \ |
|
|
--text-encoder-path path/to/gemma \ |
|
|
--prompt "A cat walking" --condition-image first_frame.png --output output.mp4 |
|
|
|
|
|
# Video-to-Video (IC-LoRA style) |
|
|
python scripts/inference.py --checkpoint path/to/model.safetensors \ |
|
|
--text-encoder-path path/to/gemma \ |
|
|
--prompt "A cat turning into a dog" --reference-video input.mp4 --output output.mp4 |
|
|
|
|
|
# With LoRA weights |
|
|
python scripts/inference.py --checkpoint path/to/model.safetensors \ |
|
|
--text-encoder-path path/to/gemma \ |
|
|
--lora-path path/to/lora.safetensors \ |
|
|
--prompt "A cat in my custom style" --output output.mp4 |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import re |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import torchaudio |
|
|
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict |
|
|
from safetensors.torch import load_file |
|
|
from torchvision import transforms |
|
|
|
|
|
from ltx_trainer.model_loader import load_model |
|
|
from ltx_trainer.progress import StandaloneSamplingProgress |
|
|
from ltx_trainer.utils import open_image_as_srgb |
|
|
from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler |
|
|
from ltx_trainer.video_utils import read_video, save_video |
|
|
|
|
|
|
|
|
def load_image(image_path: str) -> torch.Tensor:
    """Read an image from disk and return it as a float tensor [C, H, W] in [0, 1].

    The image is decoded via ``open_image_as_srgb`` so pixel values are in the
    sRGB color space before conversion to a tensor.
    """
    srgb_image = open_image_as_srgb(image_path)
    to_tensor = transforms.ToTensor()
    return to_tensor(srgb_image)
|
|
|
|
|
|
|
|
def extract_lora_target_modules(state_dict: dict[str, torch.Tensor]) -> list[str]:
    """Collect the module paths targeted by a LoRA checkpoint.

    LoRA keys follow the pattern (after removing the "diffusion_model." prefix):
    - transformer_blocks.0.attn1.to_k.lora_A.weight
    - transformer_blocks.0.ff.net.0.proj.lora_B.weight

    The full module path (e.g. "transformer_blocks.0.attn1.to_k") is extracted
    for each key; full paths are more robust than partial patterns.

    Returns:
        Sorted list of unique module paths found in the checkpoint keys.
    """
    lora_key_re = re.compile(r"(.+)\.lora_[AB]\.")
    matches = (lora_key_re.match(key) for key in state_dict)
    return sorted({m.group(1) for m in matches if m})
|
|
|
|
|
|
|
|
def load_lora_weights(transformer: torch.nn.Module, lora_path: str | Path) -> torch.nn.Module:
    """Load LoRA weights into the transformer model.

    The LoRA rank and target modules are automatically detected from the checkpoint.
    Alpha is set equal to rank (standard practice for inference).

    Args:
        transformer: The base transformer model
        lora_path: Path to the LoRA weights (.safetensors)

    Returns:
        The transformer model with LoRA weights applied
    """
    print(f"Loading LoRA weights from {lora_path}...")

    # Read the raw LoRA tensors from the safetensors file.
    state_dict = load_file(str(lora_path))

    # Strip the leading "diffusion_model." prefix (count=1 so only the first
    # occurrence is removed) so keys align with the transformer's module paths.
    state_dict = {k.replace("diffusion_model.", "", 1): v for k, v in state_dict.items()}

    # Derive the set of modules to wrap with LoRA adapters from the key names.
    target_modules = extract_lora_target_modules(state_dict)
    if not target_modules:
        raise ValueError(f"Could not extract target modules from LoRA checkpoint: {lora_path}")
    print(f" Detected {len(target_modules)} target modules")

    # Auto-detect the rank from the first 2-D lora_A matrix: for a lora_A of
    # shape [rank, in_features], dim 0 is the rank.
    lora_rank = None
    for key, value in state_dict.items():
        if "lora_A" in key and value.ndim == 2:
            lora_rank = value.shape[0]
            break
    if lora_rank is None:
        raise ValueError("Could not auto-detect LoRA rank from weights")
    print(f" LoRA rank: {lora_rank}")

    # Build a PEFT config matching the checkpoint; alpha == rank gives an
    # effective scaling factor of 1.0 at inference time.
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_rank,
        target_modules=target_modules,
        lora_dropout=0.0,
        init_lora_weights=True,
    )

    # Wrap the transformer so LoRA layers are injected at the target modules.
    transformer = get_peft_model(transformer, lora_config)

    # NOTE(review): the state dict is applied to the base model rather than the
    # PeftModel wrapper — presumably because the checkpoint keys match the base
    # module paths without peft's "base_model.model." prefix. Confirm the
    # weights actually land (set_peft_model_state_dict can silently ignore
    # non-matching keys in some peft versions).
    base_model = transformer.get_base_model()
    set_peft_model_state_dict(base_model, state_dict)

    print("✓ LoRA weights loaded successfully")
    return transformer
|
|
|
|
|
|
|
|
def main() -> None:
    """Parse CLI arguments, load the LTX model, run generation, and save outputs.

    Supports four modes selected implicitly by the conditioning flags:
    text-to-video (default), image-to-video (--condition-image),
    video-to-video (--reference-video), or both combined. Audio is generated
    alongside video unless --skip-audio is passed.
    """
    parser = argparse.ArgumentParser(
        description="LTX Video/Audio Generation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # --- Model paths ---
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to model checkpoint (.safetensors)",
    )
    parser.add_argument(
        "--text-encoder-path",
        type=str,
        required=True,
        help="Path to Gemma text encoder directory",
    )

    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to LoRA weights (.safetensors)",
    )

    # --- Generation parameters ---
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generation",
    )
    parser.add_argument(
        "--negative-prompt",
        type=str,
        default="",
        help="Negative prompt",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=544,
        help="Video height (must be divisible by 32)",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=960,
        help="Video width (must be divisible by 32)",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=97,
        help="Number of video frames (must be k*8 + 1)",
    )
    parser.add_argument(
        "--frame-rate",
        type=float,
        default=25.0,
        help="Video frame rate",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=30,
        help="Number of denoising steps",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=3.0,
        help="Classifier-free guidance scale (CFG)",
    )
    parser.add_argument(
        "--stg-scale",
        type=float,
        default=1.0,
        help="STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. Default: 1.0",
    )
    parser.add_argument(
        "--stg-blocks",
        type=int,
        nargs="*",
        default=[29],
        help="Which transformer blocks to perturb for STG. Default: 29 (single block).",
    )
    parser.add_argument(
        "--stg-mode",
        type=str,
        default="stg_av",
        choices=["stg_av", "stg_v"],
        help="STG mode: 'stg_av' perturbs both audio and video, 'stg_v' perturbs video only",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility",
    )

    # --- Conditioning inputs (select the generation mode) ---
    parser.add_argument(
        "--condition-image",
        type=str,
        default=None,
        help="Path to conditioning image for image-to-video generation",
    )
    parser.add_argument(
        "--reference-video",
        type=str,
        default=None,
        help="Path to reference video for video-to-video generation (IC-LoRA style)",
    )
    parser.add_argument(
        "--include-reference-in-output",
        action="store_true",
        help="Include reference video side-by-side with generated output (only for V2V)",
    )

    parser.add_argument(
        "--skip-audio",
        action="store_true",
        help="Skip audio generation (by default, audio is generated alongside video)",
    )

    # --- Output paths ---
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output video path (.mp4)",
    )
    parser.add_argument(
        "--audio-output",
        type=str,
        default=None,
        help="Output audio path (.wav, optional - if not provided, audio will be embedded in video)",
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run on (cuda/cpu)",
    )

    args = parser.parse_args()

    # --- Validate flag combinations and the size constraints promised in the
    # help text, so bad inputs fail fast instead of deep inside the sampler. ---
    if args.include_reference_in_output and args.reference_video is None:
        parser.error("--include-reference-in-output requires --reference-video")
    if args.height % 32 != 0 or args.width % 32 != 0:
        parser.error(
            f"--height and --width must be divisible by 32 (got {args.width}x{args.height})"
        )
    if args.num_frames < 1 or (args.num_frames - 1) % 8 != 0:
        parser.error(f"--num-frames must be k*8 + 1 (e.g. 97), got {args.num_frames}")

    generate_audio = not args.skip_audio

    print("=" * 80)
    print("LTX Video/Audio Generation")
    print("=" * 80)

    # The VAE encoder is only needed when we must encode pixels into latents
    # (image or video conditioning); skip loading it otherwise.
    need_vae_encoder = args.condition_image is not None or args.reference_video is not None

    # Load on CPU; the sampler moves components to the requested device later.
    components = load_model(
        checkpoint_path=args.checkpoint,
        device="cpu",
        dtype=torch.bfloat16,
        with_video_vae_encoder=need_vae_encoder,
        with_video_vae_decoder=True,
        with_audio_vae_decoder=generate_audio,
        with_vocoder=generate_audio,
        with_text_encoder=True,
        text_encoder_path=args.text_encoder_path,
    )

    # Optionally wrap the transformer with LoRA adapters from a checkpoint.
    transformer = components.transformer
    if args.lora_path is not None:
        transformer = load_lora_weights(transformer, args.lora_path)

    # Load the conditioning image (image-to-video), if requested.
    condition_image = None
    if args.condition_image:
        print(f"Loading conditioning image from {args.condition_image}...")
        condition_image = load_image(args.condition_image)

    # Load the reference video (video-to-video), if requested. Frame count is
    # capped at num_frames so the reference matches the generation length.
    reference_video = None
    if args.reference_video:
        print(f"Loading reference video from {args.reference_video}...")
        reference_video, ref_fps = read_video(args.reference_video, max_frames=args.num_frames)
        print(f" Loaded {reference_video.shape[0]} frames @ {ref_fps:.1f} fps")

    # Derive a human-readable mode label from the conditioning flags.
    if args.reference_video is not None and args.condition_image is not None:
        mode = "Video-to-Video + Image Conditioning (V2V+I2V)"
    elif args.reference_video is not None:
        mode = "Video-to-Video (V2V)"
    elif args.condition_image is not None:
        mode = "Image-to-Video (I2V)"
    else:
        mode = "Text-to-Video (T2V)"

    # --- Print a summary of the run configuration ---
    print("\n" + "=" * 80)
    print("Generation Parameters")
    print("=" * 80)
    print(f"Mode: {mode}")
    print(f"Prompt: {args.prompt}")
    if args.negative_prompt:
        print(f"Negative prompt: {args.negative_prompt}")
    print(f"Resolution: {args.width}x{args.height}")
    print(f"Frames: {args.num_frames} @ {args.frame_rate} fps")
    print(f"Inference steps: {args.num_inference_steps}")
    print(f"CFG scale: {args.guidance_scale}")
    if args.stg_scale > 0:
        # An empty --stg-blocks list means every block is perturbed.
        blocks_str = args.stg_blocks if args.stg_blocks else "all"
        print(f"STG scale: {args.stg_scale} (mode: {args.stg_mode}, blocks: {blocks_str})")
    else:
        print("STG: disabled")
    print(f"Seed: {args.seed}")
    if args.lora_path:
        print(f"LoRA: {args.lora_path}")
    if condition_image is not None:
        print(f"Conditioning: Image ({args.condition_image})")
    if reference_video is not None:
        print(f"Reference: Video ({args.reference_video})")
        if args.include_reference_in_output:
            print(" → Will include reference side-by-side in output")
    if generate_audio:
        video_duration = args.num_frames / args.frame_rate
        print(f"Audio: Enabled (duration will match video: {video_duration:.2f}s)")
    print("=" * 80)

    print(f"\nGenerating {'video + audio' if generate_audio else 'video'}...")

    # Bundle everything the sampler needs into a single config object.
    gen_config = GenerationConfig(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        frame_rate=args.frame_rate,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        condition_image=condition_image,
        reference_video=reference_video,
        generate_audio=generate_audio,
        include_reference_in_output=args.include_reference_in_output,
        stg_scale=args.stg_scale,
        stg_blocks=args.stg_blocks,
        stg_mode=args.stg_mode,
    )

    # Run the denoising loop with a progress display; audio components are
    # passed only when audio generation is enabled.
    with StandaloneSamplingProgress(num_steps=args.num_inference_steps) as progress:
        sampler = ValidationSampler(
            transformer=transformer,
            vae_decoder=components.video_vae_decoder,
            vae_encoder=components.video_vae_encoder,
            text_encoder=components.text_encoder,
            audio_decoder=components.audio_vae_decoder if generate_audio else None,
            vocoder=components.vocoder if generate_audio else None,
            sampling_context=progress,
        )
        video, audio = sampler.generate(
            config=gen_config,
            device=args.device,
        )

    # --- Save outputs ---
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # The vocoder's output rate is the authoritative audio sample rate.
    audio_sample_rate = None
    if audio is not None and components.vocoder is not None:
        audio_sample_rate = components.vocoder.output_sample_rate

    save_video(
        video_tensor=video,
        output_path=output_path,
        fps=args.frame_rate,
        audio=audio,
        audio_sample_rate=audio_sample_rate,
    )
    print(f"✓ Video saved to {args.output}")

    # Optionally also write the audio track as a standalone .wav file.
    if audio is not None and args.audio_output is not None:
        audio_output_path = Path(args.audio_output)
        audio_output_path.parent.mkdir(parents=True, exist_ok=True)

        torchaudio.save(
            str(audio_output_path),
            audio.cpu(),
            sample_rate=audio_sample_rate,
        )
        duration = audio.shape[1] / audio_sample_rate
        print(f"✓ Audio saved: {duration:.2f}s at {audio_sample_rate}Hz")

    print("\n" + "=" * 80)
    print("Generation complete!")
    print("=" * 80)
|
|
|
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|