#!/usr/bin/env python3
"""Warm A2Vid driver — builds :class:`A2VidPipelineTwoStage` once and runs
multiple scenes back to back without reloading Gemma, the 22B DiT, the VAEs,
or the upsampler between calls. This saves roughly 85 s of model-load time
per scene.

Scenes are declared in a JSON manifest. The file must be valid JSON (no
comments); the annotations below are for illustration only:

    {
      "scenes": [
        {
          "name": "scene01",
          "audio": "scene01.wav",
          "prompt": "...",
          "num_frames": 241,
          "tail_from": null          <- no conditioning for the first scene
        },
        {
          "name": "scene02",
          "audio": "scene02.wav",
          "prompt": "...",
          "num_frames": 201,
          "tail_from": "scene01",    <- pin the opening frames to scene01's tail
          "tail_seconds": 1.0,
          "tail_strength": 0.7
        }
      ]
    }

With ``tail_seconds: 1.0`` at 24 fps, the first 24 frames of scene02 are
pinned to the tail of scene01 at strength 0.7.

Each scene writes ``<out>/<name>.mp4``; a scene's tail frames are extracted
on demand when a later scene references it.
"""
import argparse
import json
import logging
import subprocess
import sys
import time
from pathlib import Path

import torch


def extract_tail_frames(mp4: Path, out_prefix: Path, seconds: float, fps: float) -> list[Path]:
    """Extract the last ``seconds`` seconds of ``mp4`` as PNGs named
    ``<out_prefix>_000.png``, ``<out_prefix>_001.png``, ...
    """
    dur = float(subprocess.check_output(
        ["ffprobe", "-v", "error", "-select_streams", "v:0",
         "-show_entries", "stream=duration", "-of", "csv=p=0", str(mp4)],
    ).decode().strip())
    # Back off 50 ms from the exact tail so rounding never seeks past EOF.
    start = max(0.0, dur - seconds - 0.05)
    n_frames = int(round(seconds * fps))
    # Clean stale frames left over from a previous run.
    for p in out_prefix.parent.glob(f"{out_prefix.name}_*.png"):
        p.unlink()
    # -loglevel is a global option; recent ffmpeg builds reject options that
    # trail the last output file, so it must come before the output path.
    subprocess.run(
        ["ffmpeg", "-y", "-loglevel", "error",
         "-ss", f"{start:.3f}", "-i", str(mp4),
         "-vf", f"fps={fps}", "-frames:v", str(n_frames),
         "-start_number", "0", f"{out_prefix}_%03d.png"],
        check=True,
    )
    return sorted(out_prefix.parent.glob(f"{out_prefix.name}_*.png"))
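
# Illustrative use: extract_tail_frames(Path("out/scene01.mp4"),
# Path("out/scene01_tail"), 1.0, 24.0) yields out/scene01_tail_000.png
# through out/scene01_tail_023.png.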


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True, help="JSON scene manifest")
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--checkpoint-path", required=True)
    ap.add_argument("--gemma-root", required=True)
    ap.add_argument("--spatial-upsampler-path", required=True)
    ap.add_argument("--distilled-lora", required=True)
    ap.add_argument("--quantization", default="fp8-cast",
                    choices=["fp8-cast", "none"])
    ap.add_argument("--bnb-4bit", action="store_true", default=True,
                    help="Load Gemma via the bnb-4bit path (default on).")
    ap.add_argument("--no-bnb-4bit", dest="bnb_4bit", action="store_false")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--num-inference-steps", type=int, default=30)
    ap.add_argument("--height", type=int, default=512)
    ap.add_argument("--width", type=int, default=768)
    ap.add_argument("--frame-rate", type=float, default=24.0)
    # Guider defaults match the a2vid_two_stage CLI defaults.
    ap.add_argument("--cfg-scale", type=float, default=2.5)
    ap.add_argument("--stg-scale", type=float, default=1.0)
    ap.add_argument("--rescale-scale", type=float, default=0.7)
    ap.add_argument("--modality-scale", type=float, default=2.5)
    ap.add_argument(
        "--negative-prompt",
        default="low quality, worst quality, blurry, distorted, artifacts, "
                "watermark, text, caption",
    )
    args = ap.parse_args()

    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(message)s")
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Import after argparse so --help is instant.
    from ltx_core.components.guiders import MultiModalGuiderParams
    from ltx_core.loader.registry import Registry
    from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
    from ltx_core.quantization import QuantizationPolicy
    from ltx_pipelines.a2vid_two_stage import A2VidPipelineTwoStage
    from ltx_pipelines.utils.args import ImageConditioningInput
    from ltx_pipelines.utils.blocks import PromptEncoder
    from ltx_pipelines.utils.media_io import encode_video

    quant = QuantizationPolicy.fp8_cast() if args.quantization == "fp8-cast" else None
    registry = Registry()
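
    # Assumption: QuantizationPolicy.fp8_cast() does a weight-only cast of the
    # DiT to float8 at load time; pass --quantization none to keep full bf16
    # weights if VRAM allows.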
| logging.info("Building warm A2Vid pipeline (loads Gemma, VAEs, DiT, upsampler)...") | |
| t0 = time.time() | |
| pipeline = A2VidPipelineTwoStage( | |
| checkpoint_path=args.checkpoint_path, | |
| distilled_lora=[(args.distilled_lora, 1.0, None)] if False else [], # see __init__ | |
| spatial_upsampler_path=args.spatial_upsampler_path, | |
| gemma_root=args.gemma_root, | |
| loras=(), | |
| quantization=quant, | |
| registry=registry, | |
| ) | |

    # Replace the pipeline's PromptEncoder with a warm + bnb-4bit variant so
    # subsequent calls skip the Gemma load. A2VidPipelineTwoStage stores it
    # as self.prompt_encoder.
    logging.info("Replacing PromptEncoder with warm + bnb-4bit variant...")
    pipeline.prompt_encoder = PromptEncoder(
        checkpoint_path=args.checkpoint_path,
        gemma_root=args.gemma_root,
        dtype=torch.bfloat16,
        device=pipeline.device,
        registry=registry,
        warm=True,
        use_bnb_4bit=args.bnb_4bit,
    )
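    # Assumption: warm=True keeps Gemma resident across calls, and
    # use_bnb_4bit routes it through bitsandbytes 4-bit quantization, trading
    # a little text-encoding fidelity for a large VRAM saving.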
| logging.info(f"Pipeline ready in {time.time() - t0:.1f}s") | |

    manifest = json.loads(Path(args.manifest).read_text())
    tiling = TilingConfig.default()
    mp4_paths: dict[str, Path] = {}

    for scene in manifest["scenes"]:
        name = scene["name"]
        mp4 = out_dir / f"{name}.mp4"
        mp4_paths[name] = mp4
        if mp4.exists():
            logging.info(f"[{name}] skipping — already exists")
            continue
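
        # Assumption: like other LTX-style pipelines, num_frames should be of
        # the form 8k + 1 (241 and 201 in the example manifest both are); the
        # manifest value is passed through unvalidated.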
        num_frames = int(scene["num_frames"])

        # Build image conditioning from an earlier scene's tail, if specified.
        images: list[ImageConditioningInput] = []
        tail_from = scene.get("tail_from")
        if tail_from:
            src_mp4 = mp4_paths.get(tail_from)
            if src_mp4 is None or not src_mp4.exists():
                raise RuntimeError(
                    f"scene {name} needs the tail of {tail_from}, "
                    f"which has not been generated yet"
                )
            secs = float(scene.get("tail_seconds", 1.0))
            strength = float(scene.get("tail_strength", 0.7))
            prefix = out_dir / f"{tail_from}_tail"
            logging.info(f"[{name}] extracting tail ({secs}s @ {args.frame_rate}fps) from {src_mp4.name}")
            tail_pngs = extract_tail_frames(src_mp4, prefix, secs, args.frame_rate)
            for i, png in enumerate(tail_pngs):
                # Pin tail frame i of the source scene to frame i of this one.
                images.append(ImageConditioningInput(str(png), i, strength))
| logging.info(f"[{name}] generating {num_frames} frames, {len(images)} conditioning images") | |
| t1 = time.time() | |
| video, audio = pipeline( | |
| prompt=scene["prompt"], | |
| negative_prompt=args.negative_prompt, | |
| seed=args.seed, | |
| height=args.height, | |
| width=args.width, | |
| num_frames=num_frames, | |
| frame_rate=args.frame_rate, | |
| num_inference_steps=args.num_inference_steps, | |
| video_guider_params=MultiModalGuiderParams( | |
| cfg_scale=args.cfg_scale, | |
| stg_scale=args.stg_scale, | |
| rescale_scale=args.rescale_scale, | |
| modality_scale=args.modality_scale, | |
| ), | |
| images=images, | |
| tiling_config=tiling, | |
| audio_path=scene["audio"], | |
| audio_start_time=0.0, | |
| audio_max_duration=num_frames / args.frame_rate, | |
| ) | |
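
        # Mux the generated frames with the audio track. Assumption:
        # get_video_chunks_number tells encode_video how many chunked VAE
        # decodes to expect when streaming frames to the encoder.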
        encode_video(
            video=video, fps=args.frame_rate, audio=audio,
            output_path=str(mp4),
            video_chunks_number=get_video_chunks_number(num_frames, tiling),
        )
        logging.info(f"[{name}] done in {time.time() - t1:.1f}s -> {mp4}")

    logging.info("All scenes done.")


if __name__ == "__main__":
    sys.exit(main())