# ltx-2-3 / app.py
# Author: jblast94
# Initial: LTX-2.3 video generator with Gradio UI
# Commit c48b53c (verified)
"""
LTX-2.3 Video Generator β€” Gradio Space
Text-to-Video & Image-to-Video using Lightricks LTX-2.3 distilled checkpoint.
"""
import os
import sys
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Optional
import spaces
import torch
import numpy as np
import gradio as gr
from PIL import Image
# ── Setup: install Lightricks LTX-2 packages ───────────────────────────
# Git repository that is shallow-cloned at startup to obtain the LTX-2 packages.
LTX_REPO = "https://github.com/Lightricks/LTX-2.git"

# Working directory for the clone and the ".installed" marker; created eagerly
# at import time so later path operations can assume it exists.
PACKAGES_DIR = Path("/tmp/ltx-packages")
PACKAGES_DIR.mkdir(exist_ok=True, parents=True)
def _run_logged(cmd):
    """Run ``cmd`` with output captured, printing that output if the command fails.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        # capture_output would otherwise swallow git/pip diagnostics entirely.
        print(f"[setup] Command failed ({err.returncode}): {' '.join(map(str, cmd))}")
        print(err.stdout or "")
        print(err.stderr or "")
        raise


def _refresh_import_paths():
    """Make packages that pip installed after interpreter startup importable now.

    Editable installs register themselves via ``.pth`` files, which Python only
    processes at startup; a user site-packages directory created by the install
    is likewise not yet on ``sys.path``. Re-scanning the site directories and
    invalidating the import caches lets the imports that follow succeed in this
    same process.
    """
    import importlib
    import site

    site_dirs = list(site.getsitepackages()) if hasattr(site, "getsitepackages") else []
    if site.ENABLE_USER_SITE:
        site_dirs.append(site.getusersitepackages())
    for site_dir in site_dirs:
        if os.path.isdir(site_dir):
            site.addsitedir(site_dir)
    importlib.invalidate_caches()


def ensure_packages():
    """Clone the LTX-2 repo and pip-install its packages, unless already done.

    Completion is recorded by a ``.installed`` marker in ``PACKAGES_DIR``; when the
    marker exists this returns immediately. Otherwise the repo is shallow-cloned
    into ``PACKAGES_DIR / "repo"`` (replacing any leftover from an interrupted
    run) and ``packages/ltx-core`` then ``packages/ltx-pipelines`` are installed
    in editable mode into the current interpreter.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or ``pip install`` fails
            (their captured output is printed first).
        FileNotFoundError: if an expected package directory is missing from the
            cloned repo; the marker is not written in that case.
    """
    marker = PACKAGES_DIR / ".installed"
    if marker.exists():
        return
    print("[setup] Installing LTX-2 packages...")
    repo_dir = PACKAGES_DIR / "repo"
    if repo_dir.exists():
        # A previous run got this far but never wrote the marker; git refuses to
        # clone into a non-empty directory, so start from a clean slate.
        import shutil

        shutil.rmtree(repo_dir)
    _run_logged(["git", "clone", "--depth", "1", LTX_REPO, str(repo_dir)])
    for pkg in ["packages/ltx-core", "packages/ltx-pipelines"]:
        pkg_path = repo_dir / pkg
        if not pkg_path.exists():
            # Both packages are imported right after setup; failing here gives a
            # clearer error than a later ModuleNotFoundError.
            raise FileNotFoundError(
                f"[setup] Expected package directory missing from LTX-2 repo: {pkg_path}"
            )
        _run_logged([sys.executable, "-m", "pip", "install", "-e", str(pkg_path)])
    _refresh_import_paths()
    marker.touch()
    print("[setup] Packages ready")
# Runs at import time on purpose: the ltx_pipelines imports that follow can
# only resolve once the LTX-2 packages have been cloned and pip-installed.
ensure_packages()
# ── Imports (after packages) ────────────────────────────────────────────
from huggingface_hub import hf_hub_download
from ltx_pipelines.distilled import DistilledPipeline
from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
# ── Constants ───────────────────────────────────────────────────────────
# Hugging Face Hub repositories and the checkpoint files pulled from them.
MODEL_ID = "Lightricks/LTX-2.3"
DISTILLED_CKPT = "ltx-2.3-22b-distilled.safetensors"
SPATIAL_UPSCALER = "ltx-2.3-spatial-upscaler-x2-1.1.safetensors"
GEMMA_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"

# Root directory for locally cached downloads.
CACHE_ROOT = Path("/tmp/ltx-cache")

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# NOTE(review): DTYPE is not referenced by any code visible in this file.
DTYPE = torch.bfloat16

# Pipeline singleton; stays None until the pipeline is first built.
_pipe = None
def download_checkpoint(repo_id: str, filename: str) -> str:
    """Fetch a single file from a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository identifier, e.g. ``"Lightricks/LTX-2.3"``.
        filename: Path of the file inside that repository.

    Returns:
        Local filesystem path of the downloaded (or already cached) file. The Hub
        cache lives under ``CACHE_ROOT / "hub"``.
    """
    hub_cache_dir = str(CACHE_ROOT / "hub")
    local_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=hub_cache_dir)
    return local_path
@spaces.GPU
def load_pipeline():
    """Build the LTX-2.3 ``DistilledPipeline`` on first call and return it.

    Downloads (or reuses from the Hub cache under ``CACHE_ROOT / "hub"``):
    the distilled transformer checkpoint, the x2 spatial upscaler, and a full
    snapshot of the Gemma text-encoder repository. It then constructs a
    ``DistilledPipeline`` on ``DEVICE`` with no LoRAs and ``fp8transformer=True``.
    The result is stored in the module-level ``_pipe`` and returned directly on
    subsequent calls.

    Returns:
        The cached ``DistilledPipeline`` instance.
    """
    # NOTE(review): on ZeroGPU, @spaces.GPU functions may run in a separate worker
    # process; if so, assigning the global below would not persist between calls
    # and the pipeline would be rebuilt every time. Confirm, and consider building
    # it at module import instead.
    global _pipe
    if _pipe is not None:
        return _pipe

    # Function-scope import: the module-level import only brings in hf_hub_download.
    from huggingface_hub import snapshot_download

    print(f"[load] Device: {DEVICE} | torch: {torch.__version__}")
    ckpt_path = download_checkpoint(MODEL_ID, DISTILLED_CKPT)
    upscaler_path = download_checkpoint(MODEL_ID, SPATIAL_UPSCALER)

    # The pipeline takes a Gemma *directory*. Downloading one file with
    # hf_hub_download and using its parent left a snapshot folder containing only
    # that file (no config/tokenizer). The Gemma repo's weights also appear to be
    # sharded, so a lone "model.safetensors" may not exist at all (verify).
    # Fetch the whole repo snapshot instead.
    # NOTE(review): Gemma repos are presumably license-gated — make sure an HF
    # token with access is configured for the Space.
    gemma_root = snapshot_download(repo_id=GEMMA_ID, cache_dir=str(CACHE_ROOT / "hub"))

    print(f"[load] Checkpoint: {ckpt_path}")
    print(f"[load] Gemma root: {gemma_root}")
    _pipe = DistilledPipeline(
        checkpoint_path=ckpt_path,
        gemma_root=gemma_root,
        spatial_upsampler_path=upscaler_path,
        loras=[],
        device=DEVICE,
        fp8transformer=True,
    )
    print("[load] Pipeline ready")
    return _pipe
@spaces.GPU
def generate_video(
    prompt: str,
    image: Optional[np.ndarray] = None,
    negative_prompt: str = "",
    num_frames: int = 49,
    width: int = 768,
    height: int = 512,
    guidance_scale: float = 1.0,
    num_inference_steps: int = 8,
    seed: int = -1,
) -> str:
    """Generate a video from a text prompt plus an optional conditioning image.

    Args:
        prompt: Text description of the video; must not be blank.
        image: Optional image array (as produced by ``gr.Image(type="numpy")``).
            It is converted to RGB, resized to ``width`` x ``height``, and passed
            as the single conditioning image.
        negative_prompt: Accepted for interface compatibility but currently NOT
            forwarded to the pipeline.
        num_frames: Number of frames to generate.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Denoising steps passed to the pipeline.
        seed: RNG seed; a negative value or ``None`` picks a random seed.

    Returns:
        Path of the written ``.mp4`` file. It is not deleted automatically.

    Raises:
        gr.Error: if ``prompt`` is empty or whitespace-only.
    """
    # Reject blank prompts before spending GPU time on loading/generation.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    pipe = load_pipeline()

    # Widget values may arrive as floats (e.g. 49.0) and a cleared Number box may
    # yield None; normalize before they reach the pipeline / RNG.
    num_frames = int(num_frames)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    if seed is None or seed < 0:
        seed = torch.randint(0, 2**31, (1,)).item()
    seed = int(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    # Optional conditioning image: force 3 channels (drops alpha, expands
    # grayscale) and match the requested output resolution.
    cond_images = None
    if image is not None:
        pil = Image.fromarray(image).convert("RGB").resize((width, height))
        cond_images = [pil]

    print(f"[gen] {prompt[:60]}... | {num_frames}f | seed={seed}")

    # NOTE(review): these keyword names follow a diffusers-style call; confirm they
    # match DistilledPipeline.__call__ in the installed ltx_pipelines version.
    kwargs = dict(
        prompt=prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        output_type="pt",
        condition_images=cond_images,
    )
    with torch.inference_mode():
        frames = pipe(**kwargs)

    # mkstemp + close reserves a unique path without leaving an open handle behind
    # (NamedTemporaryFile(...).name kept its file object open until GC).
    fd, out_video = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)

    from diffusers.utils import export_to_video

    # NOTE(review): export_to_video expects a sequence of frames (numpy arrays or
    # PIL images); verify what the pipeline returns for output_type="pt".
    export_to_video(frames, out_video, fps=16)
    print(f"[gen] ✅ {out_video}")
    return out_video
# ── Gradio UI ───────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft(), title="LTX-2.3 Video Generator") as demo:
    gr.Markdown(
        """
# 🎬 LTX‑2.3 Video Generator
**Lightricks LTX‑2.3** — 22B audio‑video foundation model
Uses the distilled checkpoint (8‑step turbo) for fast generation on ZeroGPU.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A cinematic drone shot over misty mountains at sunrise...",
                lines=3,
            )
            with gr.Accordion("⚙️ Settings", open=False):
                with gr.Row():
                    num_frames = gr.Slider(9, 97, value=49, step=8, label="Frames")
                    steps = gr.Slider(4, 24, value=8, step=1, label="Steps")
                with gr.Row():
                    guidance = gr.Slider(0.5, 3.0, value=1.0, step=0.1, label="CFG Scale")
                    seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
                with gr.Row():
                    width = gr.Dropdown([512, 576, 640, 768, 832, 896, 1024], value=768, label="Width")
                    height = gr.Dropdown([384, 448, 512, 576, 640, 704, 768], value=512, label="Height")
            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
        with gr.Column(scale=1):
            image_input = gr.Image(label="Input Image (optional)", type="numpy")
            video_output = gr.Video(label="Generated Video", autoplay=True)
    with gr.Row():
        gr.Examples(
            examples=[
                ["Cinematic drone shot over a misty mountain range at sunrise, golden light piercing clouds"],
                ["Fluffy Samoyed puppy playing in a tulip field, slow motion, shallow depth of field"],
                ["Cyberpunk city street at night, neon signs reflecting on wet pavement, cinematic lighting"],
                ["Majestic humpback whale breaching at sunset, slow motion, National Geographic style"],
                ["Time-lapse of cherry blossoms blooming in a Japanese garden, peaceful atmosphere"],
            ],
            inputs=[prompt],
        )

    # Gradio passes input values positionally, and generate_video's parameter
    # order is (prompt, image, negative_prompt, num_frames, width, height,
    # guidance_scale, num_inference_steps, seed). There is no negative-prompt
    # widget, so a constant empty-string State fills that slot; without it every
    # later value shifted one parameter early (Frames -> negative_prompt,
    # Width -> num_frames, ..., Seed -> num_inference_steps).
    negative_prompt_state = gr.State("")
    generate_inputs = [
        prompt,
        image_input,
        negative_prompt_state,
        num_frames,
        width,
        height,
        guidance,
        steps,
        seed,
    ]
    generate_btn.click(
        fn=generate_video,
        inputs=generate_inputs,
        outputs=video_output,
    )
    prompt.submit(
        fn=generate_video,
        inputs=generate_inputs,
        outputs=video_output,
    )
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_kwargs)