| |
| """Warm A2Vid driver — builds :class:`A2VidPipelineTwoStage` once and runs |
| multiple scenes back-to-back without reloading Gemma, the 22B DiT, VAEs or |
| the upsampler between calls. Saves ~85 s of model-load time per scene. |
| |
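| Scenes are declared in a JSON manifest (the ``#`` notes below are annotations; |
| strict JSON has no comments, so leave them out of a real manifest): |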
| |
| { |
| "scenes": [ |
| { |
| "name": "scene01", |
| "audio": "scene01.wav", |
| "prompt": "...", |
| "num_frames": 241, |
| "tail_from": null # no conditioning for first scene |
| }, |
| { |
| "name": "scene02", |
| "audio": "scene02.wav", |
| "prompt": "...", |
| "num_frames": 201, |
| "tail_from": "scene01", # pin first 24 frames to scene01 tail |
| "tail_seconds": 1.0, |
| "tail_strength": 0.7 |
| } |
| ] |
| } |
| |
| Each scene writes <out>/<name>.mp4 and its tail-frames get extracted if a |
| later scene references it. |
| """ |
| import argparse |
| import json |
| import logging |
| import os |
| import subprocess |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import torch |
|
|
|
|
def extract_tail_frames(mp4: Path, out_prefix: Path, seconds: float, fps: float) -> list[Path]:
    """Extract the last ``seconds`` seconds of ``mp4`` as PNGs starting at index 0.

    Frames are written as ``<out_prefix>_000.png``, ``<out_prefix>_001.png``,
    ... after any PNGs left over from an earlier extraction with the same
    prefix have been deleted.

    Args:
        mp4: Video to read; its duration is queried with ``ffprobe``.
        out_prefix: Directory plus file-name stem the PNG names are built from.
        seconds: Length of the tail to extract.
        fps: Rate the tail is resampled to; ``seconds * fps`` (rounded) is the
            maximum number of frames written.

    Returns:
        The PNGs that exist after extraction, in frame-index order. Empty when
        the requested frame count rounds to zero or less.

    Raises:
        subprocess.CalledProcessError: If ``ffprobe`` or ``ffmpeg`` exits
            with a non-zero status.
    """
    import glob  # local import: only needed to escape the clean-up pattern

    n_frames = int(round(seconds * fps))

    # Remove frames from a previous extraction so a shorter tail never leaves
    # stale higher-numbered PNGs behind. The stem is escaped because it comes
    # from user-chosen scene names and may contain glob metacharacters.
    stale_pattern = f"{glob.escape(out_prefix.name)}_*.png"
    for stale in list(out_prefix.parent.glob(stale_pattern)):
        stale.unlink()

    # Nothing to extract: skip the external tools instead of asking ffmpeg
    # for zero frames.
    if n_frames <= 0:
        return []

    dur = float(subprocess.check_output(
        ["ffprobe", "-v", "error", "-select_streams", "v:0",
         "-show_entries", "stream=duration", "-of", "csv=p=0", str(mp4)],
    ).decode().strip())
    # Seek 50 ms earlier than the nominal tail start, clamped to the beginning
    # of the clip (presumably slack for duration rounding; confirm).
    start = max(0.0, dur - seconds - 0.05)

    subprocess.run(
        # -loglevel is a global option, so it goes before the input/output
        # arguments rather than trailing after the output file.
        ["ffmpeg", "-loglevel", "error", "-y",
         "-ss", f"{start:.3f}", "-i", str(mp4),
         "-vf", f"fps={fps}", "-frames:v", str(n_frames),
         "-start_number", "0", f"{out_prefix}_%03d.png"],
        check=True,
    )

    # Build the expected names by index instead of sorting a glob: string
    # sorting would put ``_1000`` before ``_999`` once the index outgrows the
    # three-digit padding. ffmpeg may emit fewer frames than requested, so
    # only files that were actually written are returned.
    expected = (
        out_prefix.parent / f"{out_prefix.name}_{index:03d}.png"
        for index in range(n_frames)
    )
    return [frame for frame in expected if frame.exists()]
|
|
|
|
def _preflight_manifest(manifest: dict, out_dir: Path) -> None:
    """Raise now on manifest errors the render loop would otherwise hit mid-run.

    Walks the scenes in order and applies the same per-scene requirements as
    the loop in ``main``: scenes whose ``<out_dir>/<name>.mp4`` already exists
    are skipped, every other scene needs ``num_frames`` (int-convertible),
    ``prompt`` and ``audio``, and a ``tail_from`` must name a scene listed
    earlier in the manifest.

    Raises:
        KeyError: If ``scenes``, a scene ``name`` or a required field is missing.
        RuntimeError: If ``tail_from`` does not name an earlier scene.
    """
    earlier: set[str] = set()
    for scene in manifest["scenes"]:
        name = scene["name"]
        if not (out_dir / f"{name}.mp4").exists():
            missing = [key for key in ("num_frames", "prompt", "audio") if key not in scene]
            if missing:
                raise KeyError(f"scene {name} is missing required field(s): {', '.join(missing)}")
            int(scene["num_frames"])  # same conversion the render loop performs
            tail_from = scene.get("tail_from")
            if tail_from and tail_from not in earlier:
                raise RuntimeError(f"scene {name} needs tail from {tail_from} which hasn't been generated")
        earlier.add(name)


def main() -> None:
    """Render every scene of a manifest with one warm A2Vid pipeline.

    Parses the command line, creates the output directory, validates the
    manifest, builds :class:`A2VidPipelineTwoStage` once and then generates
    the scenes in manifest order, writing ``<out-dir>/<name>.mp4`` for each.
    Scenes whose output file already exists are skipped. A scene with
    ``tail_from`` is conditioned on PNG frames extracted from the end of the
    referenced scene's mp4.

    Raises:
        KeyError: If the manifest lacks ``scenes`` or a required scene field.
        RuntimeError: If a scene's ``tail_from`` does not refer to an earlier
            scene whose mp4 is available.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True, help="JSON scene manifest")
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--checkpoint-path", required=True)
    ap.add_argument("--gemma-root", required=True)
    ap.add_argument("--spatial-upsampler-path", required=True)
    ap.add_argument("--distilled-lora", required=True)
    ap.add_argument("--quantization", default="fp8-cast",
                    choices=["fp8-cast", "none"])
    ap.add_argument("--bnb-4bit", action="store_true", default=True,
                    help="Load Gemma via bnb-4bit path (default on).")
    ap.add_argument("--no-bnb-4bit", dest="bnb_4bit", action="store_false")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--num-inference-steps", type=int, default=30)
    ap.add_argument("--height", type=int, default=512)
    ap.add_argument("--width", type=int, default=768)
    ap.add_argument("--frame-rate", type=float, default=24.0)

    ap.add_argument("--cfg-scale", type=float, default=2.5)
    ap.add_argument("--stg-scale", type=float, default=1.0)
    ap.add_argument("--rescale-scale", type=float, default=0.7)
    ap.add_argument("--modality-scale", type=float, default=2.5)
    ap.add_argument("--negative-prompt", default=
                    "low quality, worst quality, blurry, distorted, artifacts, watermark, text, caption")
    args = ap.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Read and check the manifest before any model is loaded, so a missing
    # file, malformed JSON or a bad ``tail_from`` fails immediately instead of
    # after the pipeline build. Passing bytes lets json.loads detect the
    # UTF-8/16/32 encoding itself rather than relying on the locale default.
    manifest = json.loads(Path(args.manifest).read_bytes())
    _preflight_manifest(manifest, out_dir)

    # Project imports are deferred until the arguments and manifest are known
    # to be usable.
    from ltx_core.loader.registry import Registry
    from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
    from ltx_core.quantization import QuantizationPolicy
    from ltx_pipelines.a2vid_two_stage import A2VidPipelineTwoStage
    from ltx_pipelines.utils.args import ImageConditioningInput
    from ltx_pipelines.utils.blocks import PromptEncoder
    from ltx_core.components.guiders import MultiModalGuiderParams
    from ltx_pipelines.utils.media_io import encode_video

    quant = QuantizationPolicy.fp8_cast() if args.quantization == "fp8-cast" else None
    registry = Registry()

    logging.info("Building warm A2Vid pipeline (loads Gemma, VAEs, DiT, upsampler)...")
    t0 = time.time()
    pipeline = A2VidPipelineTwoStage(
        checkpoint_path=args.checkpoint_path,
        # NOTE(review): --distilled-lora is required on the command line but is
        # never forwarded. This argument was previously written as
        # ``[(args.distilled_lora, 1.0, None)] if False else []``, which always
        # evaluates to ``[]``. Behaviour is kept as-is; confirm whether the
        # LoRA is meant to be applied before re-enabling it.
        distilled_lora=[],
        spatial_upsampler_path=args.spatial_upsampler_path,
        gemma_root=args.gemma_root,
        loras=(),
        quantization=quant,
        registry=registry,
    )

    # Swap the pipeline's prompt encoder for one constructed with warm=True
    # and the bnb-4bit choice from the command line.
    logging.info("Replacing PromptEncoder with warm + bnb-4bit variant...")
    pipeline.prompt_encoder = PromptEncoder(
        checkpoint_path=args.checkpoint_path,
        gemma_root=args.gemma_root,
        dtype=torch.bfloat16,
        device=pipeline.device,
        registry=registry,
        warm=True,
        use_bnb_4bit=args.bnb_4bit,
    )
    logging.info(f"Pipeline ready in {time.time() - t0:.1f}s")

    tiling = TilingConfig.default()
    # Output path of every scene seen so far (generated or skipped), used to
    # resolve ``tail_from`` references.
    mp4_paths: dict[str, Path] = {}

    for scene in manifest["scenes"]:
        name = scene["name"]
        mp4 = out_dir / f"{name}.mp4"
        mp4_paths[name] = mp4
        if mp4.exists():
            logging.info(f"[{name}] skipping — already exists")
            continue

        num_frames = int(scene["num_frames"])

        # Conditioning frames: frame i of the referenced scene's tail is
        # attached at frame index i of this scene with the given strength.
        images: list[ImageConditioningInput] = []
        tail_from = scene.get("tail_from")
        if tail_from:
            src_mp4 = mp4_paths.get(tail_from)
            if src_mp4 is None or not src_mp4.exists():
                raise RuntimeError(f"scene {name} needs tail from {tail_from} which hasn't been generated")
            secs = float(scene.get("tail_seconds", 1.0))
            strength = float(scene.get("tail_strength", 0.7))
            prefix = out_dir / f"{tail_from}_tail"
            logging.info(f"[{name}] extracting tail ({secs}s @ {args.frame_rate}fps) from {src_mp4.name}")
            tail_pngs = extract_tail_frames(src_mp4, prefix, secs, args.frame_rate)
            images = [
                ImageConditioningInput(str(png), i, strength)
                for i, png in enumerate(tail_pngs)
            ]

        logging.info(f"[{name}] generating {num_frames} frames, {len(images)} conditioning images")
        t1 = time.time()
        video, audio = pipeline(
            prompt=scene["prompt"],
            negative_prompt=args.negative_prompt,
            seed=args.seed,
            height=args.height,
            width=args.width,
            num_frames=num_frames,
            frame_rate=args.frame_rate,
            num_inference_steps=args.num_inference_steps,
            video_guider_params=MultiModalGuiderParams(
                cfg_scale=args.cfg_scale,
                stg_scale=args.stg_scale,
                rescale_scale=args.rescale_scale,
                modality_scale=args.modality_scale,
            ),
            images=images,
            tiling_config=tiling,
            audio_path=scene["audio"],
            audio_start_time=0.0,
            # Cap the audio at the duration of the requested video frames.
            audio_max_duration=num_frames / args.frame_rate,
        )
        encode_video(
            video=video, fps=args.frame_rate, audio=audio,
            output_path=str(mp4),
            video_chunks_number=get_video_chunks_number(num_frames, tiling),
        )
        logging.info(f"[{name}] done in {time.time() - t1:.1f}s -> {mp4}")

    logging.info("All scenes done.")
|
|
|
|
# Script entry point: run the driver only when executed directly, and hand
# main()'s return value to the interpreter as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|