#!/usr/bin/env python3
"""Warm A2Vid driver — builds :class:`A2VidPipelineTwoStage` once and runs
multiple scenes back-to-back without reloading Gemma, the 22B DiT, the VAEs,
or the upsampler between calls. Saves ~85 s of model-load time per scene.

Scenes are declared in a JSON manifest:

    {
      "scenes": [
        {
          "name": "scene01",
          "audio": "scene01.wav",
          "prompt": "...",
          "num_frames": 241,
          "tail_from": null          # no conditioning for the first scene
        },
        {
          "name": "scene02",
          "audio": "scene02.wav",
          "prompt": "...",
          "num_frames": 201,
          "tail_from": "scene01",    # pin the first 24 frames to scene01's tail
          "tail_seconds": 1.0,
          "tail_strength": 0.7
        }
      ]
    }

(The ``#`` comments above are illustrative only; strip them from real JSON.)

Each scene writes <out-dir>/<name>.mp4, and its tail frames are extracted
if a later scene references it.
"""

import argparse
import json
import logging
import subprocess
import sys
import time
from pathlib import Path

import torch


def extract_tail_frames(mp4: Path, out_prefix: Path, seconds: float, fps: float) -> list[Path]:
    """Extract the last ``seconds`` seconds of ``mp4`` as PNGs starting at index 0."""
    dur = float(subprocess.check_output(
        ["ffprobe", "-v", "error", "-select_streams", "v:0",
         "-show_entries", "stream=duration", "-of", "csv=p=0", str(mp4)],
    ).decode().strip())
    # Back off 50 ms from the nominal start time so the final frame isn't missed.
    start = max(0.0, dur - seconds - 0.05)
    n_frames = int(round(seconds * fps))
    # Clean stale tail frames left over from a previous run.
    for p in out_prefix.parent.glob(f"{out_prefix.name}_*.png"):
        p.unlink()
    subprocess.run(
        ["ffmpeg", "-y", "-ss", f"{start:.3f}", "-i", str(mp4),
         "-vf", f"fps={fps}", "-frames:v", str(n_frames),
         "-start_number", "0", f"{out_prefix}_%03d.png",
         "-loglevel", "error"],
        check=True,
    )
    return sorted(out_prefix.parent.glob(f"{out_prefix.name}_*.png"))


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True, help="JSON scene manifest")
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--checkpoint-path", required=True)
    ap.add_argument("--gemma-root", required=True)
    ap.add_argument("--spatial-upsampler-path", required=True)
    ap.add_argument("--distilled-lora", required=True)
    ap.add_argument("--quantization", default="fp8-cast", choices=["fp8-cast", "none"])
    ap.add_argument("--bnb-4bit", action="store_true", default=True,
                    help="Load Gemma via the bnb-4bit path (default on).")
    ap.add_argument("--no-bnb-4bit", dest="bnb_4bit", action="store_false")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--num-inference-steps", type=int, default=30)
    ap.add_argument("--height", type=int, default=512)
    ap.add_argument("--width", type=int, default=768)
    ap.add_argument("--frame-rate", type=float, default=24.0)
    # Guider defaults match the a2vid_two_stage CLI defaults.
    ap.add_argument("--cfg-scale", type=float, default=2.5)
    ap.add_argument("--stg-scale", type=float, default=1.0)
    ap.add_argument("--rescale-scale", type=float, default=0.7)
    ap.add_argument("--modality-scale", type=float, default=2.5)
    ap.add_argument("--negative-prompt", default=(
        "low quality, worst quality, blurry, distorted, artifacts, "
        "watermark, text, caption"))
    args = ap.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
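    # Hedged addition, not part of the original flow: fail fast on a malformed
    # manifest *before* spending ~85 s loading models. The required key names
    # below assume the manifest shape documented in the module docstring, and
    # the ordering check mirrors the main loop's rule that `tail_from` must
    # name a scene appearing earlier in the manifest.
    _scenes = json.loads(Path(args.manifest).read_text()).get("scenes", [])
    _earlier: set[str] = set()
    for _s in _scenes:
        for _key in ("name", "audio", "prompt", "num_frames"):
            if _key not in _s:
                ap.error(f"scene {_s.get('name', '?')!r} is missing required key {_key!r}")
        _ref = _s.get("tail_from")
        if _ref is not None and _ref not in _earlier:
            ap.error(f"scene {_s['name']!r}: tail_from={_ref!r} must name an earlier scene")
        _earlier.add(_s["name"])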
    # Import the heavy dependencies after argparse so --help stays instant.
    from ltx_core.components.guiders import MultiModalGuiderParams
    from ltx_core.loader.registry import Registry
    from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
    from ltx_core.quantization import QuantizationPolicy
    from ltx_pipelines.a2vid_two_stage import A2VidPipelineTwoStage
    from ltx_pipelines.utils.args import ImageConditioningInput
    from ltx_pipelines.utils.blocks import PromptEncoder
    from ltx_pipelines.utils.media_io import encode_video

    quant = QuantizationPolicy.fp8_cast() if args.quantization == "fp8-cast" else None
    registry = Registry()

    logging.info("Building warm A2Vid pipeline (loads Gemma, VAEs, DiT, upsampler)...")
    t0 = time.time()
    pipeline = A2VidPipelineTwoStage(
        checkpoint_path=args.checkpoint_path,
        distilled_lora=[(args.distilled_lora, 1.0, None)],  # (path, strength, target)
        spatial_upsampler_path=args.spatial_upsampler_path,
        gemma_root=args.gemma_root,
        loras=(),
        quantization=quant,
        registry=registry,
    )

    # Replace the pipeline's PromptEncoder with a warm + bnb-4bit variant so
    # subsequent calls skip the Gemma load. A2VidPipelineTwoStage stores it
    # as self.prompt_encoder.
    logging.info("Replacing PromptEncoder with warm + bnb-4bit variant...")
    pipeline.prompt_encoder = PromptEncoder(
        checkpoint_path=args.checkpoint_path,
        gemma_root=args.gemma_root,
        dtype=torch.bfloat16,
        device=pipeline.device,
        registry=registry,
        warm=True,
        use_bnb_4bit=args.bnb_4bit,
    )
    logging.info(f"Pipeline ready in {time.time() - t0:.1f}s")

    manifest = json.loads(Path(args.manifest).read_text())
    tiling = TilingConfig.default()
    mp4_paths: dict[str, Path] = {}

    for scene in manifest["scenes"]:
        name = scene["name"]
        mp4 = out_dir / f"{name}.mp4"
        mp4_paths[name] = mp4
        if mp4.exists():
            logging.info(f"[{name}] skipping — already exists")
            continue

        num_frames = int(scene["num_frames"])

        # Build image conditioning from an earlier scene's tail, if specified.
        images: list[ImageConditioningInput] = []
        tail_from = scene.get("tail_from")
        if tail_from:
            src_mp4 = mp4_paths.get(tail_from)
            if src_mp4 is None or not src_mp4.exists():
                raise RuntimeError(
                    f"scene {name} needs a tail from {tail_from}, which hasn't been generated")
            secs = float(scene.get("tail_seconds", 1.0))
            strength = float(scene.get("tail_strength", 0.7))
            prefix = out_dir / f"{tail_from}_tail"
            logging.info(f"[{name}] extracting tail ({secs}s @ {args.frame_rate}fps) from {src_mp4.name}")
            tail_pngs = extract_tail_frames(src_mp4, prefix, secs, args.frame_rate)
            # Pin tail frame i of the previous scene to frame i of this scene.
            for i, png in enumerate(tail_pngs):
                images.append(ImageConditioningInput(str(png), i, strength))

        logging.info(f"[{name}] generating {num_frames} frames, {len(images)} conditioning images")
        t1 = time.time()
        video, audio = pipeline(
            prompt=scene["prompt"],
            negative_prompt=args.negative_prompt,
            seed=args.seed,
            height=args.height,
            width=args.width,
            num_frames=num_frames,
            frame_rate=args.frame_rate,
            num_inference_steps=args.num_inference_steps,
            video_guider_params=MultiModalGuiderParams(
                cfg_scale=args.cfg_scale,
                stg_scale=args.stg_scale,
                rescale_scale=args.rescale_scale,
                modality_scale=args.modality_scale,
            ),
            images=images,
            tiling_config=tiling,
            audio_path=scene["audio"],
            audio_start_time=0.0,
            audio_max_duration=num_frames / args.frame_rate,
        )
        encode_video(
            video=video,
            fps=args.frame_rate,
            audio=audio,
            output_path=str(mp4),
            video_chunks_number=get_video_chunks_number(num_frames, tiling),
        )
        logging.info(f"[{name}] done in {time.time() - t1:.1f}s -> {mp4}")

    logging.info("All scenes done.")


if __name__ == "__main__":
    sys.exit(main())
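
# ---------------------------------------------------------------------------
# Usage sketch. The paths and the script file name (warm_a2vid_driver.py) are
# illustrative assumptions, not fixed by this repo:
#
#   python warm_a2vid_driver.py \
#       --manifest scenes.json \
#       --out-dir out \
#       --checkpoint-path /models/a2vid/dit-22b.safetensors \
#       --gemma-root /models/gemma \
#       --spatial-upsampler-path /models/a2vid/spatial_upsampler.safetensors \
#       --distilled-lora /models/a2vid/distilled_lora.safetensors
#
# With the manifest from the module docstring, this renders scene01 first,
# then extracts its last second as PNGs and pins them to the opening frames
# of scene02, amortizing the single pipeline build over both scenes.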