import os
import subprocess
import sys
import math

# Disable torch.compile / dynamo via environment variables. This must happen
# before torch is imported further down in this file.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Install xformers for memory-efficient attention.
# check=False: a failed install does not raise here; the import of xformers
# later in this file is wrapped in try/except.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"],
    check=False
)

# Clone the LTX-2 repo next to this file and install its packages.
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")

# Shallow clone only when the checkout directory does not exist yet.
if not os.path.exists(LTX_REPO_DIR):
    print(f"Cloning {LTX_REPO_URL}...")
    subprocess.run(["git", "clone", "--depth", "1", LTX_REPO_URL, LTX_REPO_DIR], check=True)

# NOTE(review): the original indentation of this file was lost; the install
# step below is assumed to run on every start (outside the `if` above), which
# matches its use of --force-reinstall. Confirm against the upstream file.
print("Installing ltx-core and ltx-pipelines from cloned repo...")
subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps",
        "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-core"),
        "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines")
    ],
    check=True,
)

# Put the cloned package sources at the front of sys.path so the imports of
# ltx_core / ltx_pipelines below resolve to the checkout.
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))

import logging
import random
import tempfile
from pathlib import Path

import torch
# Second line of defence in addition to the environment variables above:
# switch dynamo off through its runtime config as well.
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

import spaces
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download, snapshot_download

from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_core.quantization import QuantizationPolicy
from ltx_pipelines.distilled import DistilledPipeline
from ltx_pipelines.utils.args import ImageConditioningInput
from ltx_pipelines.utils.media_io import encode_video

# Force-patch xformers attention into the LTX attention module.
from ltx_core.model.transformer import attention as _attn_mod
print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
try:
    # Overwrite the module-level attribute with the xformers implementation.
    from xformers.ops import memory_efficient_attention as _mea
    _attn_mod.memory_efficient_attention = _mea
    print(f"[ATTN] After patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
except Exception as e:
    # Best effort: if xformers cannot be imported, log the failure and keep
    # whatever value the attention module already had.
    print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")

logging.getLogger().setLevel(logging.INFO)

# Largest seed value accepted by the UI (max of a signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
DEFAULT_FRAME_RATE = 24.0

# Official LTX-2.3 maximum length of a single generation: 20 seconds.
MAX_DURATION_SECONDS = 20.0

# Long-video threshold, introduced to reduce the failure rate of videos
# approaching 20 seconds.
LONG_VIDEO_THRESHOLD = 10.0

# Resolution presets: (width, height), keyed by tier and then by aspect label.
# Note that the "16:9" entries are actually 3:2 (1536x1024 and 768x512).
RESOLUTIONS = {
    "high": {
        "16:9": (1536, 1024),
        "9:16": (1024, 1536),
        "1:1": (1024, 1024),
    },
    "low": {
        "16:9": (768, 512),
        "9:16": (512, 768),
        "1:1": (768, 768),
    },
}

# Model repos
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"

print("=" * 80)
print("Downloading LTX-2.3 distilled model + Gemma...")
print("=" * 80)

# Fetch the two single-file checkpoints and the full Gemma snapshot from the
# Hugging Face Hub; each call returns a local path.
checkpoint_path = hf_hub_download(
    repo_id=LTX_MODEL_REPO,
    filename="ltx-2.3-22b-distilled.safetensors"
)
spatial_upsampler_path = hf_hub_download(
    repo_id=LTX_MODEL_REPO,
    filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors"
)
gemma_root = snapshot_download(repo_id=GEMMA_REPO)

print(f"Checkpoint: {checkpoint_path}")
print(f"Spatial upsampler: {spatial_upsampler_path}")
print(f"Gemma root: {gemma_root}")

# Initialize pipeline WITH text encoder (gemma_root is passed in), no LoRAs,
# and the fp8_cast quantization policy.
pipeline = DistilledPipeline(
    distilled_checkpoint_path=checkpoint_path,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    loras=[],
    quantization=QuantizationPolicy.fp8_cast(),
)

# Preload all models for ZeroGPU tensor packing.
print("Preloading all models (including Gemma)...")
ledger = pipeline.model_ledger

# Build every sub-model once at import time ...
_transformer = ledger.transformer()
_video_encoder = ledger.video_encoder()
_video_decoder = ledger.video_decoder()
_audio_decoder = ledger.audio_decoder()
_vocoder = ledger.vocoder()
_spatial_upsampler = ledger.spatial_upsampler()
_text_encoder = ledger.text_encoder()
_embeddings_processor = ledger.gemma_embeddings_processor()

# ... then replace the ledger factories with closures that hand back the
# already-built instances, so later calls do not construct them again.
ledger.transformer = lambda: _transformer
ledger.video_encoder = lambda: _video_encoder
ledger.video_decoder = lambda: _video_decoder
ledger.audio_decoder = lambda: _audio_decoder
ledger.vocoder = lambda: _vocoder
ledger.spatial_upsampler = lambda: _spatial_upsampler
ledger.text_encoder = lambda: _text_encoder
ledger.gemma_embeddings_processor = lambda: _embeddings_processor
print("All models preloaded (including Gemma text encoder)!")

print("=" * 80)
print("Pipeline ready!")
print("=" * 80)


def log_memory(tag: str):
    """Print CUDA memory statistics, labelled with *tag*.

    Useful when diagnosing failures of long generations. Does nothing when
    CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    gib = 1024**3
    allocated = torch.cuda.memory_allocated() / gib
    peak = torch.cuda.max_memory_allocated() / gib
    free, total = torch.cuda.mem_get_info()
    print(
        f"[VRAM {tag}] "
        f"allocated={allocated:.2f}GB "
        f"peak={peak:.2f}GB "
        f"free={free / 1024**3:.2f}GB "
        f"total={total / 1024**3:.2f}GB"
    )


def detect_aspect_ratio(image) -> str:
    """Return the preset aspect label closest to the image's width/height.

    Accepts anything exposing ``.size`` as ``(w, h)`` (PIL style) or
    ``.shape`` as ``(h, w, ...)`` (array style). Falls back to ``"16:9"``
    for ``None``, for unsupported objects, and for degenerate images with a
    zero dimension (which would otherwise divide by zero).
    """
    if image is None:
        return "16:9"

    if hasattr(image, "size"):
        w, h = image.size
    elif hasattr(image, "shape"):
        h, w = image.shape[:2]
    else:
        return "16:9"

    if w <= 0 or h <= 0:
        return "16:9"

    ratio = w / h
    candidates = {
        "16:9": 16 / 9,
        "9:16": 9 / 16,
        "1:1": 1.0,
    }
    return min(candidates, key=lambda k: abs(ratio - candidates[k]))


def get_resolution_by_state(image, high_res: bool, duration: float):
    """Pick the recommended resolution for the current UI state.

    The aspect label comes from the image. Videos longer than
    ``LONG_VIDEO_THRESHOLD`` seconds are always forced to the "low" tier to
    reduce out-of-memory and timeout failures; otherwise the tier follows
    the ``high_res`` switch.

    Returns:
        ``(width, height, tier, aspect)``.
    """
    aspect = detect_aspect_ratio(image)

    if duration > LONG_VIDEO_THRESHOLD:
        tier = "low"
    else:
        tier = "high" if high_res else "low"

    w, h = RESOLUTIONS[tier][aspect]
    return w, h, tier, aspect


def on_image_upload(image, high_res, duration):
    """Gradio handler: sync width/height/status after the image changes."""
    w, h, tier, aspect = get_resolution_by_state(image, bool(high_res), float(duration))
    tip = f"已自动匹配比例 {aspect},当前使用 {tier} 分辨率:{w}×{h}"
    return gr.update(value=w), gr.update(value=h), gr.update(value=tip)


def on_highres_toggle(image, high_res, duration):
    """Gradio handler: sync width/height/status after the high-res switch.

    When high resolution is requested for a long video, the status text
    explains that the low tier is used anyway.
    """
    duration = float(duration)
    high_res = bool(high_res)
    w, h, tier, aspect = get_resolution_by_state(image, high_res, duration)

    if duration > LONG_VIDEO_THRESHOLD and high_res:
        tip = f"当前时长 {duration:.1f}s,已为稳定性自动降为 low 分辨率:{w}×{h}"
    else:
        tip = f"已自动匹配比例 {aspect},当前使用 {tier} 分辨率:{w}×{h}"

    return gr.update(value=w), gr.update(value=h), gr.update(value=tip)


def on_duration_change(image, high_res, duration):
    """Gradio handler: sync width/height/status after the duration changes."""
    duration = float(duration)
    w, h, tier, aspect = get_resolution_by_state(image, bool(high_res), duration)

    if duration > LONG_VIDEO_THRESHOLD:
        tip = (
            f"当前时长 {duration:.1f}s,已自动切换到 low 分辨率 {w}×{h},"
            f"以降低显存占用和超时风险。"
        )
    else:
        tip = f"当前时长 {duration:.1f}s,比例 {aspect},使用 {tier} 分辨率:{w}×{h}"

    return gr.update(value=w), gr.update(value=h), gr.update(value=tip)


def clamp_int(v, min_v, max_v):
    """Convert *v* to ``int`` and clamp it to ``[min_v, max_v]``."""
    return max(min_v, min(int(v), max_v))


def align_num_frames(duration: float, frame_rate: float) -> int:
    """Return a frame count of the form ``8n + 1`` covering *duration*.

    The raw count is ``int(duration * frame_rate) + 1``; it is then rounded
    up to the next value of the form ``8n + 1``.

    Example: 20 s * 24 fps = 480 frames -> 481 after alignment.
    """
    raw_frames = int(duration * frame_rate) + 1
    aligned_frames = ((raw_frames - 1 + 7) // 8) * 8 + 1
    return aligned_frames


def _remove_quietly(path) -> None:
    """Best-effort delete of a scratch file; never raises."""
    if path is None:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as exc:
        print(f"[CLEANUP] Could not remove {path}: {exc}")


# 20-second videos take noticeably longer to run, hence the larger GPU
# duration budget.
@spaces.GPU(duration=240)
@torch.inference_mode()
def generate_video(
    input_image,
    prompt: str,
    duration: float,
    enhance_prompt: bool = True,
    seed: int = 42,
    randomize_seed: bool = True,
    height: int = 1024,
    width: int = 1536,
    high_res: bool = True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video (with audio) and encode it to an MP4 file.

    Args:
        input_image: Optional conditioning image. Objects with a ``save``
            method (PIL images) are written to a scratch JPEG; anything else
            is treated as a filesystem path.
        prompt: Text prompt passed to the pipeline.
        duration: Requested length in seconds; clamped to
            ``[1.0, MAX_DURATION_SECONDS]``.
        enhance_prompt: Forwarded to the pipeline's ``enhance_prompt``.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a random seed in ``[0, MAX_SEED]`` instead.
        height: Requested height; clamped to ``[256, 2048]``. ``None`` (a
            cleared number field) falls back to the preset resolution.
        width: Requested width; same handling as ``height``.
        high_res: Whether the high-resolution tier is requested.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        ``(video_path, seed_used, status_message)``. On failure the path is
        ``None`` and the message contains the error; exceptions are not
        propagated to the caller.
    """
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    scratch_image = None  # conditioning image file created by this call
    output_path = None    # MP4 file created by this call

    try:
        # Guarded the same way as log_memory() so a CPU-only run does not
        # fail here before reaching the pipeline.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        log_memory("start")

        # ---------- Parameter safety limits ----------
        duration = max(1.0, min(float(duration), MAX_DURATION_SECONDS))
        frame_rate = DEFAULT_FRAME_RATE
        num_frames = align_num_frames(duration, frame_rate)

        safe_w, safe_h, safe_tier, _ = get_resolution_by_state(
            input_image, bool(high_res), duration
        )

        # Integer conversion and bounds protection. A cleared Width/Height
        # field arrives as None; use the preset instead of failing on int(None).
        # NOTE(review): manual values are only clamped, not aligned; the
        # pipeline may require dimensions divisible by some multiple — confirm.
        width = clamp_int(safe_w if width is None else width, 256, 2048)
        height = clamp_int(safe_h if height is None else height, 256, 2048)

        # Long videos are downgraded to the preset resolution to improve the
        # success rate.
        if duration > LONG_VIDEO_THRESHOLD:
            if width != safe_w or height != safe_h:
                print(
                    f"[SAFE] Long video detected ({duration:.1f}s). "
                    f"Override resolution from {width}x{height} to {safe_w}x{safe_h}"
                )
            width, height = safe_w, safe_h

        print(
            f"Generating: {height}x{width}, "
            f"{num_frames} frames ({duration:.1f}s), "
            f"seed={current_seed}, high_res={high_res}, safe_tier={safe_tier}"
        )

        images = []
        if input_image is not None:
            output_dir = Path("outputs")
            output_dir.mkdir(exist_ok=True)

            if hasattr(input_image, "save"):
                # Unique file name: a seed-only name lets two requests with
                # the same fixed seed overwrite each other's input.
                fd, scratch_name = tempfile.mkstemp(
                    prefix=f"temp_input_{current_seed}_",
                    suffix=".jpg",
                    dir=str(output_dir),
                )
                os.close(fd)
                scratch_image = Path(scratch_name)

                # JPEG cannot store alpha or palette modes; normalise to RGB
                # so the save below cannot fail on the image mode.
                if getattr(input_image, "mode", "RGB") != "RGB":
                    input_image = input_image.convert("RGB")
                input_image.save(scratch_image)
                image_path = scratch_image
            else:
                image_path = Path(input_image)

            images = [
                ImageConditioningInput(
                    path=str(image_path),
                    frame_idx=0,
                    strength=1.0
                )
            ]

        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

        log_memory("before pipeline call")

        video, audio = pipeline(
            prompt=prompt,
            seed=current_seed,
            height=int(height),
            width=int(width),
            num_frames=num_frames,
            frame_rate=frame_rate,
            images=images,
            tiling_config=tiling_config,
            enhance_prompt=bool(enhance_prompt),
        )

        log_memory("after pipeline call")

        # mkstemp creates the file atomically; tempfile.mktemp is deprecated
        # because the name it returns can be claimed by another process.
        fd, output_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        encode_video(
            video=video,
            fps=frame_rate,
            audio=audio,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )

        log_memory("after encode_video")

        return str(output_path), current_seed, (
            f"生成成功:{duration:.1f} 秒,{num_frames} 帧,输出分辨率 {width}×{height}"
        )

    except Exception as e:
        # Top-level UI boundary: report the failure in the status box rather
        # than letting the exception escape to Gradio.
        import traceback

        # Do not leave an empty or partial MP4 behind.
        _remove_quietly(output_path)

        log_memory("on error")
        err = f"{type(e).__name__}: {str(e)}"
        print(f"Error: {err}\n{traceback.format_exc()}")

        user_msg = (
            "生成失败。\n"
            f"错误:{err}\n\n"
            "建议:\n"
            "1. 20秒视频请优先使用低分辨率\n"
            "2. 先关闭 High Resolution\n"
            "3. 输入图尽量简单,减少复杂运动\n"
            "4. 如在 ZeroGPU / Hugging Face Space 上运行,长视频可能仍会因排队或时限失败"
        )
        return None, current_seed, user_msg

    finally:
        # The conditioning image is only needed while this call runs.
        _remove_quietly(scratch_image)


# NOTE(review): the original indentation of this file was lost; the nesting
# of the layout containers below is reconstructed — confirm against upstream.
with gr.Blocks(title="LTX-2.3 Distilled") as demo:
    gr.Markdown("# LTX-2.3 Distilled (22B): Fast Audio-Video Generation")
    gr.Markdown(
        "Fast and high quality video + audio generation \n"
        "[[model]](https://huggingface.co/Lightricks/LTX-2.3) "
        "[[code]](https://github.com/Lightricks/LTX-2)"
    )
    gr.Markdown(
        "说明:已支持最长 20 秒视频。为提高成功率,超过 10 秒时会自动切换为低分辨率。"
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image (Optional)", type="pil")
            prompt = gr.Textbox(
                label="Prompt",
                info="for best results - make it as elaborate as possible",
                value="Make this image come alive with cinematic motion, smooth animation",
                lines=3,
                placeholder="Describe the motion and animation you want...",
            )

            with gr.Row():
                duration = gr.Slider(
                    label="Duration (seconds)",
                    minimum=1.0,
                    # Same limit generate_video() clamps to (20 seconds).
                    maximum=MAX_DURATION_SECONDS,
                    value=3.0,
                    step=0.1
                )
                with gr.Column():
                    enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
                    high_res = gr.Checkbox(label="High Resolution", value=True)

            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=10, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                with gr.Row():
                    width = gr.Number(label="Width", value=1536, precision=0)
                    height = gr.Number(label="Height", value=1024, precision=0)

            status_text = gr.Textbox(
                label="Status",
                value="就绪",
                interactive=False,
                lines=4
            )

        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)

    # Re-sync the resolution when an image is uploaded or cleared.
    input_image.change(
        fn=on_image_upload,
        inputs=[input_image, high_res, duration],
        outputs=[width, height, status_text],
    )

    # Re-sync the resolution when the high-resolution switch is toggled.
    high_res.change(
        fn=on_highres_toggle,
        inputs=[input_image, high_res, duration],
        outputs=[width, height, status_text],
    )

    # Re-sync the resolution when the duration changes.
    duration.change(
        fn=on_duration_change,
        inputs=[input_image, high_res, duration],
        outputs=[width, height, status_text],
    )

    # Input order must match generate_video's positional parameters
    # (note: height comes before width).
    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image,
            prompt,
            duration,
            enhance_prompt,
            seed,
            randomize_seed,
            height,
            width,
            high_res,
        ],
        outputs=[output_video, seed, status_text],
    )


css = """ .fillable {max-width: 1200px !important;} """

if __name__ == "__main__":
    demo.launch(theme=gr.themes.Citrus(), css=css)