# LTX-2-3 / app.py — Hugging Face Space by cpuai (commit f23241b, verified)
import os
import subprocess
import sys
import math

# Disable torch.compile / dynamo before any torch import
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Install xformers for memory-efficient attention (best effort: check=False,
# so a failed install falls back to the default attention path).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"],
    check=False
)

# Clone LTX-2 repo and install its packages (editable, no deps) so the
# ltx_core / ltx_pipelines imports further down resolve.
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
if not os.path.exists(LTX_REPO_DIR):
    print(f"Cloning {LTX_REPO_URL}...")
    subprocess.run(["git", "clone", "--depth", "1", LTX_REPO_URL, LTX_REPO_DIR], check=True)
print("Installing ltx-core and ltx-pipelines from cloned repo...")
subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "-e",
        os.path.join(LTX_REPO_DIR, "packages", "ltx-core"),
        "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines")
    ],
    check=True,
)
# Belt-and-braces: also put the package sources directly on sys.path in case
# the editable install is not visible to this interpreter.
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
import logging
import random
import tempfile
from pathlib import Path

import torch

# Disable dynamo at the torch config level too (in addition to the env vars
# set above, before torch was imported).
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

import spaces
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download, snapshot_download
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_core.quantization import QuantizationPolicy
from ltx_pipelines.distilled import DistilledPipeline
from ltx_pipelines.utils.args import ImageConditioningInput
from ltx_pipelines.utils.media_io import encode_video

# Force-patch xformers attention into the LTX attention module.
from ltx_core.model.transformer import attention as _attn_mod

print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
try:
    from xformers.ops import memory_efficient_attention as _mea
    _attn_mod.memory_efficient_attention = _mea
    print(f"[ATTN] After patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
except Exception as e:
    # Best effort: if xformers is unavailable, keep the module's default attention.
    print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
logging.getLogger().setLevel(logging.INFO)

MAX_SEED = np.iinfo(np.int32).max
DEFAULT_FRAME_RATE = 24.0
# LTX-2.3 officially caps a single generation at 20 seconds.
MAX_DURATION_SECONDS = 20.0
# Durations above this are treated as "long videos" and forced to the low
# preset, to reduce the failure rate of 20-second generations.
LONG_VIDEO_THRESHOLD = 10.0

# Resolution presets: (width, height)
RESOLUTIONS = {
    "high": {
        "16:9": (1536, 1024),
        "9:16": (1024, 1536),
        "1:1": (1024, 1024),
    },
    "low": {
        "16:9": (768, 512),
        "9:16": (512, 768),
        "1:1": (768, 768),
    },
}

# Model repos
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"

print("=" * 80)
print("Downloading LTX-2.3 distilled model + Gemma...")
print("=" * 80)
checkpoint_path = hf_hub_download(
    repo_id=LTX_MODEL_REPO,
    filename="ltx-2.3-22b-distilled.safetensors"
)
spatial_upsampler_path = hf_hub_download(
    repo_id=LTX_MODEL_REPO,
    filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors"
)
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
print(f"Checkpoint: {checkpoint_path}")
print(f"Spatial upsampler: {spatial_upsampler_path}")
print(f"Gemma root: {gemma_root}")

# Initialize pipeline WITH text encoder (Gemma), fp8-cast quantization.
pipeline = DistilledPipeline(
    distilled_checkpoint_path=checkpoint_path,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    loras=[],
    quantization=QuantizationPolicy.fp8_cast(),
)
# Preload all models for ZeroGPU tensor packing.
print("Preloading all models (including Gemma)...")
ledger = pipeline.model_ledger
_transformer = ledger.transformer()
_video_encoder = ledger.video_encoder()
_video_decoder = ledger.video_decoder()
_audio_decoder = ledger.audio_decoder()
_vocoder = ledger.vocoder()
_spatial_upsampler = ledger.spatial_upsampler()
_text_encoder = ledger.text_encoder()
_embeddings_processor = ledger.gemma_embeddings_processor()
# Monkey-patch the ledger accessors so every later call returns the already
# loaded instances instead of re-loading from disk.
ledger.transformer = lambda: _transformer
ledger.video_encoder = lambda: _video_encoder
ledger.video_decoder = lambda: _video_decoder
ledger.audio_decoder = lambda: _audio_decoder
ledger.vocoder = lambda: _vocoder
ledger.spatial_upsampler = lambda: _spatial_upsampler
ledger.text_encoder = lambda: _text_encoder
ledger.gemma_embeddings_processor = lambda: _embeddings_processor
print("All models preloaded (including Gemma text encoder)!")
print("=" * 80)
print("Pipeline ready!")
print("=" * 80)
def log_memory(tag: str):
    """Print allocated/peak/free/total CUDA memory, labelled with *tag*.

    No-op on CPU-only hosts. Used to diagnose OOMs during long-video
    generation.
    """
    if not torch.cuda.is_available():
        return
    gib = 1024 ** 3
    allocated = torch.cuda.memory_allocated() / gib
    peak = torch.cuda.max_memory_allocated() / gib
    free, total = torch.cuda.mem_get_info()
    print(
        f"[VRAM {tag}] "
        f"allocated={allocated:.2f}GB "
        f"peak={peak:.2f}GB "
        f"free={free / gib:.2f}GB "
        f"total={total / gib:.2f}GB"
    )
def detect_aspect_ratio(image) -> str:
    """Return the closest supported aspect ratio for *image*.

    Accepts a PIL image (``.size`` -> (w, h)), an array-like object
    (``.shape`` -> (h, w, ...)) or ``None``. Falls back to ``"16:9"`` when
    the dimensions cannot be determined.
    """
    if image is None:
        return "16:9"
    # BUGFIX: check .shape BEFORE .size. numpy arrays expose both, but for
    # them .size is the total element count (an int), so the original
    # `w, h = image.size` raised TypeError for array inputs.
    if hasattr(image, "shape"):
        h, w = image.shape[:2]
    elif hasattr(image, "size"):
        w, h = image.size
    else:
        return "16:9"
    if not h:
        # Degenerate (zero-height) image: avoid ZeroDivisionError.
        return "16:9"
    ratio = w / h
    candidates = {
        "16:9": 16 / 9,
        "9:16": 9 / 16,
        "1:1": 1.0,
    }
    return min(candidates, key=lambda k: abs(ratio - candidates[k]))
def get_resolution_by_state(image, high_res: bool, duration: float):
    """Resolve the target resolution from image aspect, quality flag and length.

    Returns (width, height, tier, aspect). Requests longer than
    LONG_VIDEO_THRESHOLD seconds are forced onto the "low" preset to cut
    OOM and timeout risk for 20-second videos.
    """
    aspect = detect_aspect_ratio(image)
    # High tier only when requested AND the clip is short enough.
    use_high = high_res and duration <= LONG_VIDEO_THRESHOLD
    tier = "high" if use_high else "low"
    width, height = RESOLUTIONS[tier][aspect]
    return width, height, tier, aspect
def on_image_upload(image, high_res, duration):
    """Auto-fill the width/height fields and status tip after an image upload."""
    width, height, tier, aspect = get_resolution_by_state(
        image, bool(high_res), float(duration)
    )
    message = f"已自动匹配比例 {aspect},当前使用 {tier} 分辨率:{width}×{height}"
    return (
        gr.update(value=width),
        gr.update(value=height),
        gr.update(value=message),
    )
def on_highres_toggle(image, high_res, duration):
    """Refresh the resolution fields when the High Resolution checkbox flips."""
    wants_high = bool(high_res)
    secs = float(duration)
    width, height, tier, aspect = get_resolution_by_state(image, wants_high, secs)
    if secs > LONG_VIDEO_THRESHOLD and wants_high:
        # Long clip: the high preset was silently downgraded — tell the user.
        message = f"当前时长 {duration:.1f}s,已为稳定性自动降为 low 分辨率:{width}×{height}"
    else:
        message = f"已自动匹配比例 {aspect},当前使用 {tier} 分辨率:{width}×{height}"
    return (
        gr.update(value=width),
        gr.update(value=height),
        gr.update(value=message),
    )
def on_duration_change(image, high_res, duration):
    """Re-sync the resolution fields whenever the duration slider moves."""
    secs = float(duration)
    width, height, tier, aspect = get_resolution_by_state(image, bool(high_res), secs)
    if secs > LONG_VIDEO_THRESHOLD:
        message = (
            f"当前时长 {duration:.1f}s,已自动切换到 low 分辨率 {width}×{height},"
            f"以降低显存占用和超时风险。"
        )
    else:
        message = f"当前时长 {duration:.1f}s,比例 {aspect},使用 {tier} 分辨率:{width}×{height}"
    return (
        gr.update(value=width),
        gr.update(value=height),
        gr.update(value=message),
    )
def clamp_int(v, min_v, max_v):
    """Coerce *v* to int and clamp it into [min_v, max_v] (upper bound first)."""
    capped = min(int(v), max_v)
    return max(min_v, capped)
def align_num_frames(duration: float, frame_rate: float) -> int:
    """Compute a frame count of the 8n+1 form LTX expects.

    Example: 20 s at 24 fps is 480 frames, which aligns up to 481.
    """
    # Frames beyond the first one; the +1 below re-adds the initial frame.
    extra = int(duration * frame_rate)
    # Ceil-divide to the next multiple of 8, then add the leading frame.
    return ((extra + 7) // 8) * 8 + 1
# 20-second videos take markedly longer to infer, so the GPU allocation
# duration is raised accordingly.
@spaces.GPU(duration=240)
@torch.inference_mode()
def generate_video(
    input_image,
    prompt: str,
    duration: float,
    enhance_prompt: bool = True,
    seed: int = 42,
    randomize_seed: bool = True,
    height: int = 1024,
    width: int = 1536,
    high_res: bool = True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video (with audio) from a prompt and optional image.

    Returns (output_path_or_None, seed_used, status_message). All errors
    are caught and surfaced through the status message instead of raising,
    so the Gradio UI never sees an exception.
    """
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    try:
        torch.cuda.reset_peak_memory_stats()
        log_memory("start")
        # ---------- parameter safety limits ----------
        duration = max(1.0, min(float(duration), MAX_DURATION_SECONDS))
        frame_rate = DEFAULT_FRAME_RATE
        num_frames = align_num_frames(duration, frame_rate)
        # Integer coercion + bounds protection for width/height.
        width = clamp_int(width, 256, 2048)
        height = clamp_int(height, 256, 2048)
        # For long videos, auto-downgrade the resolution to raise success rate.
        safe_w, safe_h, safe_tier, safe_aspect = get_resolution_by_state(input_image, bool(high_res), duration)
        if duration > LONG_VIDEO_THRESHOLD:
            if width != safe_w or height != safe_h:
                print(
                    f"[SAFE] Long video detected ({duration:.1f}s). "
                    f"Override resolution from {width}x{height} to {safe_w}x{safe_h}"
                )
                width, height = safe_w, safe_h
        print(
            f"Generating: {height}x{width}, "
            f"{num_frames} frames ({duration:.1f}s), "
            f"seed={current_seed}, high_res={high_res}, safe_tier={safe_tier}"
        )
        images = []
        if input_image is not None:
            output_dir = Path("outputs")
            output_dir.mkdir(exist_ok=True)
            temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
            if hasattr(input_image, "save"):
                # PIL image from the UI: persist it so the pipeline can read a path.
                input_image.save(temp_image_path)
            else:
                # Assume the value is already a filesystem path.
                temp_image_path = Path(input_image)
            images = [
                ImageConditioningInput(
                    path=str(temp_image_path),
                    frame_idx=0,
                    strength=1.0
                )
            ]
        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
        log_memory("before pipeline call")
        video, audio = pipeline(
            prompt=prompt,
            seed=current_seed,
            height=int(height),
            width=int(width),
            num_frames=num_frames,
            frame_rate=frame_rate,
            images=images,
            tiling_config=tiling_config,
            enhance_prompt=bool(enhance_prompt),
        )
        log_memory("after pipeline call")
        # NOTE(review): tempfile.mktemp is deprecated/race-prone; consider
        # NamedTemporaryFile(delete=False, suffix=".mp4") instead.
        output_path = tempfile.mktemp(suffix=".mp4")
        encode_video(
            video=video,
            fps=frame_rate,
            audio=audio,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )
        log_memory("after encode_video")
        return str(output_path), current_seed, (
            f"生成成功:{duration:.1f} 秒,{num_frames} 帧,输出分辨率 {width}×{height}"
        )
    except Exception as e:
        import traceback
        log_memory("on error")
        err = f"{type(e).__name__}: {str(e)}"
        print(f"Error: {err}\n{traceback.format_exc()}")
        user_msg = (
            "生成失败。\n"
            f"错误:{err}\n\n"
            "建议:\n"
            "1. 20秒视频请优先使用低分辨率\n"
            "2. 先关闭 High Resolution\n"
            "3. 输入图尽量简单,减少复杂运动\n"
            "4. 如在 ZeroGPU / Hugging Face Space 上运行,长视频可能仍会因排队或时限失败"
        )
        return None, current_seed, user_msg
# UI layout: two columns — inputs on the left, the generated video on the right.
# NOTE(review): the scraped source lost its indentation; the nesting below is
# the natural reconstruction, confirm against the original Space.
with gr.Blocks(title="LTX-2.3 Distilled") as demo:
    gr.Markdown("# LTX-2.3 Distilled (22B): Fast Audio-Video Generation")
    gr.Markdown(
        "Fast and high quality video + audio generation \n"
        "[[model]](https://huggingface.co/Lightricks/LTX-2.3) "
        "[[code]](https://github.com/Lightricks/LTX-2)"
    )
    gr.Markdown(
        "说明:已支持最长 20 秒视频。为提高成功率,超过 10 秒时会自动切换为低分辨率。"
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image (Optional)", type="pil")
            prompt = gr.Textbox(
                label="Prompt",
                info="for best results - make it as elaborate as possible",
                value="Make this image come alive with cinematic motion, smooth animation",
                lines=3,
                placeholder="Describe the motion and animation you want...",
            )
            with gr.Row():
                duration = gr.Slider(
                    label="Duration (seconds)",
                    minimum=1.0,
                    maximum=20.0,  # raised to the 20-second cap
                    value=3.0,
                    step=0.1
                )
                with gr.Column():
                    enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
                    high_res = gr.Checkbox(label="High Resolution", value=True)
            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=10, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                with gr.Row():
                    width = gr.Number(label="Width", value=1536, precision=0)
                    height = gr.Number(label="Height", value=1024, precision=0)
            status_text = gr.Textbox(
                label="Status",
                value="就绪",
                interactive=False,
                lines=4
            )
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)
    # Auto-adjust resolution when an image is uploaded.
    input_image.change(
        fn=on_image_upload,
        inputs=[input_image, high_res, duration],
        outputs=[width, height, status_text],
    )
    # Auto-adjust when the High Resolution checkbox is toggled.
    high_res.change(
        fn=on_highres_toggle,
        inputs=[input_image, high_res, duration],
        outputs=[width, height, status_text],
    )
    # Auto-adjust when the duration slider changes.
    duration.change(
        fn=on_duration_change,
        inputs=[input_image, high_res, duration],
        outputs=[width, height, status_text],
    )
    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image,
            prompt,
            duration,
            enhance_prompt,
            seed,
            randomize_seed,
            height,
            width,
            high_res,
        ],
        outputs=[output_video, seed, status_text],
    )
css = """
.fillable {max-width: 1200px !important;}
"""
if __name__ == "__main__":
demo.launch(theme=gr.themes.Citrus(), css=css)