Dramabox / scripts /warm_a2vid.py
Manmay's picture
DramaBox Space — initial app + vendored ltx2
08c5e28 verified
#!/usr/bin/env python3
"""Warm A2Vid driver — builds :class:`A2VidPipelineTwoStage` once and runs
multiple scenes back-to-back without reloading Gemma, the 22B DiT, VAEs or
the upsampler between calls. Saves ~85 s of model-load time per scene.
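Scenes are declared in a JSON manifest (the ``#`` notes in the example are
explanatory only; JSON has no comment syntax, so omit them in a real file):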
{
"scenes": [
{
"name": "scene01",
"audio": "scene01.wav",
"prompt": "...",
"num_frames": 241,
"tail_from": null # no conditioning for first scene
},
{
"name": "scene02",
"audio": "scene02.wav",
"prompt": "...",
"num_frames": 201,
"tail_from": "scene01", # pin first 24 frames to scene01 tail
"tail_seconds": 1.0,
"tail_strength": 0.7
}
]
}
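Each scene writes <out-dir>/<name>.mp4; a scene whose output file already
exists is skipped. When a later scene names an earlier one in ``tail_from``,
the last ``tail_seconds`` of that earlier video are extracted as PNG frames
and used as image conditioning for the later scene.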
"""
import argparse
import json
import logging
import os
import subprocess
import sys
import time
from pathlib import Path
import torch
def _indexed_tail_frames(out_prefix: Path) -> list[Path]:
    """Return the ``<out_prefix>_<digits>.png`` files ordered by numeric index.

    Only names made of the exact prefix, an underscore, ASCII digits and
    ``.png`` are accepted, so frames that belong to a longer prefix (for
    example ``a_tail_tail_000.png`` when the prefix is ``a_tail``) are never
    picked up. Ordering is numeric, so it stays correct once the index grows
    past the three digits of the ``%03d`` pattern.
    """
    folder = out_prefix.parent
    if not folder.is_dir():
        return []
    head = f"{out_prefix.name}_"
    suffix = ".png"
    indexed: list[tuple[int, Path]] = []
    for candidate in folder.iterdir():
        name = candidate.name
        if not (name.startswith(head) and name.endswith(suffix)):
            continue
        digits = name[len(head):-len(suffix)]
        if digits.isascii() and digits.isdigit():
            indexed.append((int(digits), candidate))
    indexed.sort(key=lambda pair: (pair[0], pair[1].name))
    return [path for _, path in indexed]


def extract_tail_frames(mp4: Path, out_prefix: Path, seconds: float, fps: float) -> list[Path]:
    """Extract the last ``seconds`` seconds of ``mp4`` as PNGs starting at index 0.

    Frames are written as ``<out_prefix>_000.png``, ``<out_prefix>_001.png``, ...
    Frames left over from an earlier extraction with the same prefix are
    removed first so the returned list only contains fresh files.

    Args:
        mp4: Video to read; its first video stream's duration is queried
            with ``ffprobe``.
        out_prefix: Path prefix (directory + file-name stem) for the PNGs.
        seconds: Length of the tail to extract.
        fps: Rate at which the tail is resampled; ``round(seconds * fps)``
            frames are requested.

    Returns:
        The extracted PNG paths in frame order. Empty when the request
        rounds to zero frames (no external tool is run in that case).

    Raises:
        ValueError: ``seconds`` is negative or ``fps`` is not positive.
        subprocess.CalledProcessError: ``ffprobe`` or ``ffmpeg`` failed.
    """
    if seconds < 0 or fps <= 0:
        raise ValueError(
            f"seconds must be >= 0 and fps must be > 0 (got seconds={seconds}, fps={fps})"
        )
    n_frames = int(round(seconds * fps))

    # Clean stale frames up front so a failed run cannot leave old frames
    # behind that look like the result of this extraction.
    for stale in _indexed_tail_frames(out_prefix):
        stale.unlink()
    if n_frames <= 0:
        return []

    dur = float(subprocess.check_output(
        ["ffprobe", "-v", "error", "-select_streams", "v:0",
         "-show_entries", "stream=duration", "-of", "csv=p=0", str(mp4)],
    ).decode().strip())
    # Seek slightly (50 ms) before the nominal tail start, clamped to the
    # beginning of the file for clips shorter than the requested tail.
    start = max(0.0, dur - seconds - 0.05)
    subprocess.run(
        # -loglevel is a global option, so it goes before the input/output
        # specifications rather than trailing after the output file.
        ["ffmpeg", "-y", "-loglevel", "error",
         "-ss", f"{start:.3f}", "-i", str(mp4),
         "-vf", f"fps={fps}", "-frames:v", str(n_frames),
         "-start_number", "0", f"{out_prefix}_%03d.png"],
        check=True,
    )
    return _indexed_tail_frames(out_prefix)
def _preflight_scenes(scenes: list, out_dir: Path) -> int:
    """Validate the manifest in render order before any model is loaded.

    Mirrors the render loop: scenes whose ``<out_dir>/<name>.mp4`` already
    exists are skipped there, so they are not validated here either.

    Returns:
        The number of scenes that still have to be rendered.

    Raises:
        KeyError: a scene has no ``name``, or a scene to be rendered lacks
            ``prompt``, ``audio`` or ``num_frames``.
        ValueError / TypeError: ``num_frames``, ``tail_seconds`` or
            ``tail_strength`` is not numeric.
        RuntimeError: ``tail_from`` names a scene that is not listed earlier
            in the manifest, so its video could never exist when needed.
    """
    earlier: set[str] = set()
    pending = 0
    for scene in scenes:
        name = scene["name"]
        if not (out_dir / f"{name}.mp4").exists():
            pending += 1
            missing = [key for key in ("prompt", "audio", "num_frames") if key not in scene]
            if missing:
                raise KeyError(f"scene {name} is missing required key(s): {', '.join(missing)}")
            int(scene["num_frames"])
            float(scene.get("tail_seconds", 1.0))
            float(scene.get("tail_strength", 0.7))
            tail_from = scene.get("tail_from")
            if tail_from and tail_from not in earlier:
                raise RuntimeError(f"scene {name} needs tail from {tail_from} which hasn't been generated")
        earlier.add(name)
    return pending


def main():
    """Render every scene of a JSON manifest with one warm A2Vid pipeline.

    Parses the CLI, validates the manifest, builds the pipeline once and then
    renders each scene to ``<out-dir>/<name>.mp4``. Scenes whose output file
    already exists are skipped, which makes the script resumable. A scene with
    ``tail_from`` is conditioned on PNG frames extracted from the end of the
    referenced earlier scene.

    The manifest is read and checked before the heavy imports and the model
    load, so a malformed manifest fails immediately instead of after the
    pipeline has been built or after earlier scenes have been rendered.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True, help="JSON scene manifest")
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--checkpoint-path", required=True)
    ap.add_argument("--gemma-root", required=True)
    ap.add_argument("--spatial-upsampler-path", required=True)
    ap.add_argument("--distilled-lora", required=True)
    ap.add_argument("--quantization", default="fp8-cast",
                    choices=["fp8-cast", "none"])
    ap.add_argument("--bnb-4bit", action="store_true", default=True,
                    help="Load Gemma via bnb-4bit path (default on).")
    ap.add_argument("--no-bnb-4bit", dest="bnb_4bit", action="store_false")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--num-inference-steps", type=int, default=30)
    ap.add_argument("--height", type=int, default=512)
    ap.add_argument("--width", type=int, default=768)
    ap.add_argument("--frame-rate", type=float, default=24.0)
    # Defaults for guider params (match a2vid_two_stage CLI defaults)
    ap.add_argument("--cfg-scale", type=float, default=2.5)
    ap.add_argument("--stg-scale", type=float, default=1.0)
    ap.add_argument("--rescale-scale", type=float, default=0.7)
    ap.add_argument("--modality-scale", type=float, default=2.5)
    ap.add_argument("--negative-prompt", default=
                    "low quality, worst quality, blurry, distorted, artifacts, watermark, text, caption")
    args = ap.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Read the manifest as UTF-8 (the JSON interchange encoding) rather than
    # the platform locale, and validate it before any expensive work.
    manifest = json.loads(Path(args.manifest).read_text(encoding="utf-8"))
    scenes = manifest["scenes"]
    if _preflight_scenes(scenes, out_dir) == 0:
        # Nothing to render: do not pay for the model load at all.
        for scene in scenes:
            logging.info(f"[{scene['name']}] skipping — already exists")
        logging.info("All scenes done.")
        return

    # Import after argparse so --help is instant.
    from ltx_core.loader.registry import Registry
    from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
    from ltx_core.quantization import QuantizationPolicy
    from ltx_pipelines.a2vid_two_stage import A2VidPipelineTwoStage
    from ltx_pipelines.utils.args import ImageConditioningInput
    from ltx_pipelines.utils.blocks import PromptEncoder
    from ltx_core.components.guiders import MultiModalGuiderParams
    from ltx_pipelines.utils.media_io import encode_video

    quant = QuantizationPolicy.fp8_cast() if args.quantization == "fp8-cast" else None
    registry = Registry()

    logging.info("Building warm A2Vid pipeline (loads Gemma, VAEs, DiT, upsampler)...")
    t0 = time.time()
    pipeline = A2VidPipelineTwoStage(
        checkpoint_path=args.checkpoint_path,
        # NOTE(review): --distilled-lora is required on the CLI but is not
        # forwarded. The previous expression
        # `[(args.distilled_lora, 1.0, None)] if False else []` always
        # evaluated to [], so that behaviour is kept here; the original
        # comment pointed at A2VidPipelineTwoStage.__init__ for the reason.
        # Confirm whether this is still intended.
        distilled_lora=[],
        spatial_upsampler_path=args.spatial_upsampler_path,
        gemma_root=args.gemma_root,
        loras=(),
        quantization=quant,
        registry=registry,
    )
    # Replace the pipeline's PromptEncoder with a warm+bnb one so subsequent
    # calls skip the Gemma load. A2VidPipelineTwoStage stored it as
    # self.prompt_encoder.
    logging.info("Replacing PromptEncoder with warm + bnb-4bit variant...")
    pipeline.prompt_encoder = PromptEncoder(
        checkpoint_path=args.checkpoint_path,
        gemma_root=args.gemma_root,
        dtype=torch.bfloat16,
        device=pipeline.device,
        registry=registry,
        warm=True,
        use_bnb_4bit=args.bnb_4bit,
    )
    logging.info(f"Pipeline ready in {time.time() - t0:.1f}s")

    tiling = TilingConfig.default()
    # Output path of every scene seen so far, keyed by scene name, so later
    # scenes can look up the video they take their tail from.
    mp4_paths: dict[str, Path] = {}
    for scene in scenes:
        name = scene["name"]
        mp4 = out_dir / f"{name}.mp4"
        mp4_paths[name] = mp4
        if mp4.exists():
            logging.info(f"[{name}] skipping — already exists")
            continue
        num_frames = int(scene["num_frames"])

        # Build image conditioning from an earlier scene's tail, if specified.
        images: list[ImageConditioningInput] = []
        tail_from = scene.get("tail_from")
        if tail_from:
            src_mp4 = mp4_paths.get(tail_from)
            if src_mp4 is None or not src_mp4.exists():
                raise RuntimeError(f"scene {name} needs tail from {tail_from} which hasn't been generated")
            secs = float(scene.get("tail_seconds", 1.0))
            strength = float(scene.get("tail_strength", 0.7))
            prefix = out_dir / f"{tail_from}_tail"
            logging.info(f"[{name}] extracting tail ({secs}s @ {args.frame_rate}fps) from {src_mp4.name}")
            tail_pngs = extract_tail_frames(src_mp4, prefix, secs, args.frame_rate)
            # Tail frame i is passed with frame index i and the scene's strength.
            images = [
                ImageConditioningInput(str(png), i, strength)
                for i, png in enumerate(tail_pngs)
            ]

        logging.info(f"[{name}] generating {num_frames} frames, {len(images)} conditioning images")
        t1 = time.time()
        video, audio = pipeline(
            prompt=scene["prompt"],
            negative_prompt=args.negative_prompt,
            seed=args.seed,
            height=args.height,
            width=args.width,
            num_frames=num_frames,
            frame_rate=args.frame_rate,
            num_inference_steps=args.num_inference_steps,
            video_guider_params=MultiModalGuiderParams(
                cfg_scale=args.cfg_scale,
                stg_scale=args.stg_scale,
                rescale_scale=args.rescale_scale,
                modality_scale=args.modality_scale,
            ),
            images=images,
            tiling_config=tiling,
            audio_path=scene["audio"],
            audio_start_time=0.0,
            # Use only as much audio as the requested number of frames covers.
            audio_max_duration=num_frames / args.frame_rate,
        )
        # Encode to a sibling file and rename once complete: the resume check
        # above treats any existing <name>.mp4 as finished, so a half-written
        # file from an interrupted run must never appear under the final name.
        # The temporary name keeps the .mp4 suffix for the encoder.
        partial_mp4 = out_dir / f"{name}.partial.mp4"
        encode_video(
            video=video, fps=args.frame_rate, audio=audio,
            output_path=str(partial_mp4),
            video_chunks_number=get_video_chunks_number(num_frames, tiling),
        )
        partial_mp4.replace(mp4)
        logging.info(f"[{name}] done in {time.time() - t1:.1f}s -> {mp4}")
    logging.info("All scenes done.")
# Script entry point: run the driver and hand its result to the interpreter
# as the process exit status (main() returns None, i.e. exit code 0).
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)