# NOTE: the "Spaces: Paused / Paused" lines here were Hugging Face page-status
# residue captured with the source, not code; preserved as this comment.
import spaces
import gradio as gr
import torch
import random
import numpy as np  # NOTE(review): not used in the visible code — possibly needed elsewhere; confirm before removing
from diffusers.pipelines.ltx2 import (
    LTX2Pipeline,
    LTX2LatentUpsamplePipeline
)
from diffusers.pipelines.ltx2.latent_upsampler import (
    LTX2LatentUpsamplerModel
)
from diffusers.pipelines.ltx2.utils import (
    DISTILLED_SIGMA_VALUES,
    STAGE_2_DISTILLED_SIGMA_VALUES
)
from diffusers.pipelines.ltx2.export_utils import encode_video

# ============================================================
# 🔥 GLOBAL SETTINGS (H200 SAFE)
# ============================================================
# Allow TF32 matmuls/convolutions for faster inference on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Inference-only app: disable autograd globally.
torch.set_grad_enabled(False)
# Prefer flash / memory-efficient scaled-dot-product attention kernels.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

DEVICE = "cuda"
DTYPE = torch.bfloat16
MODEL_ID = "rootonchair/LTX-2-19b-distilled"

print("🚀 Loading LTX-2 Text-to-Video...")

# Main two-stage text-to-video pipeline (distilled 19B checkpoint).
pipe = LTX2Pipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
).to(DEVICE)

# Latent upsampler used between stage 1 and stage 2.
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    MODEL_ID,
    subfolder="latent_upsampler",
    torch_dtype=DTYPE,
).to(DEVICE)

# Upsampling pipeline shares the VAE with the main pipeline (no extra VRAM).
upsample_pipe = LTX2LatentUpsamplePipeline(
    vae=pipe.vae,
    latent_upsampler=latent_upsampler
).to(DEVICE)

# Fuse the detailer IC-LoRA into the base weights, then drop the adapter
# bookkeeping so inference runs at fused-weight speed.
pipe.load_lora_weights(
    "Lightricks/LTX-2-19b-IC-LoRA-Detailer",
    adapter_name="camera_control"
)
pipe.fuse_lora(lora_scale=0.8)
pipe.unload_lora_weights()

print("🔥 Model fully loaded on CUDA.")
print("✅ Model Loaded")
| # ============================================================ | |
| # 🎬 GENERATION | |
| # ============================================================ | |
# ============================================================
# 🎬 GENERATION
# ============================================================
@spaces.GPU(duration=120)  # required on ZeroGPU hardware; no-op on a dedicated GPU
def generate(prompt, negative_prompt, duration, seed):
    """Generate a short video (with audio) from a text prompt.

    Runs the two-stage distilled LTX-2 recipe: a stage-1 pass produces
    video/audio latents, the video latents are spatially upsampled, then a
    short stage-2 refinement pass decodes to pixel frames which are muxed
    with the audio into an MP4.

    Args:
        prompt: Text description of the desired clip.
        negative_prompt: Concepts to steer the sampler away from.
        duration: Requested clip length in seconds; the frame count is
            snapped to the ``8k + 1`` grid the model requires.
        seed: RNG seed; ``-1`` picks a random seed.

    Returns:
        Tuple ``(output_path, seed)``: path of the encoded MP4 and the
        seed actually used (so the UI can display a resolved random seed).
    """
    if seed == -1:
        seed = random.randint(0, 1_000_000)
    # Use the module-level DEVICE constant so the generator always lives on
    # the same device as the pipelines (was hard-coded "cuda").
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    width = 768
    height = 512
    fps = 24
    # LTX-2 expects num_frames ≡ 1 (mod 8); round to the nearest valid
    # count and never go below 9 frames.
    total_frames = int(duration * fps)
    num_frames = max((round(total_frames / 8) * 8) + 1, 9)

    with torch.inference_mode():
        # Stage 1: distilled sampling in latent space (8 steps, CFG off).
        video_latent, audio_latent = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_frames=num_frames,
            frame_rate=fps,
            num_inference_steps=8,
            sigmas=DISTILLED_SIGMA_VALUES,
            guidance_scale=1.0,
            generator=generator,
            output_type="latent",
            return_dict=False,
        )
        # Spatial upscale of the video latents (audio latents pass through).
        upscaled_video_latent = upsample_pipe(
            latents=video_latent,
            output_type="latent",
            return_dict=False,
        )[0]
        # Stage 2: short refinement pass on the upscaled latents, decoded
        # straight to numpy frames.
        video, audio = pipe(
            latents=upscaled_video_latent,
            audio_latents=audio_latent,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=3,
            noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
            sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
            generator=generator,
            guidance_scale=1.0,
            output_type="np",
            return_dict=False,
        )

    # Float frames in [0, 1] -> uint8, then to a tensor for the encoder.
    video = (video * 255).round().astype("uint8")
    video = torch.from_numpy(video)

    output_path = f"t2v_{seed}.mp4"
    encode_video(
        video[0],
        fps=fps,
        audio=audio[0].float().cpu(),
        audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
        output_path=output_path,
    )
    return output_path, seed
| # ============================================================ | |
| # 🖥️ UI | |
| # ============================================================ | |
# ============================================================
# 🖥️ UI
# ============================================================
# Gradio front-end: prompt controls feed generate(); the resolved seed is
# written back into the seed field so random runs are reproducible.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 LTX-2 Distilled Text-to-Video")

    result = gr.Video()

    prompt = gr.Textbox(
        label="Prompt",
        lines=3,
        value="A cinematic sunset over the ocean",
    )
    negative_prompt = gr.Textbox(
        label="Negative Prompt",
        value="low quality, distorted, glitchy",
    )
    duration = gr.Slider(1, 12, value=4, label="Duration (seconds)")
    seed = gr.Number(value=-1, precision=0, label="Seed")

    btn = gr.Button("Generate", variant="primary")
    btn.click(
        fn=generate,
        inputs=[prompt, negative_prompt, duration, seed],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    # queue() serializes concurrent requests so the single GPU isn't
    # oversubscribed.
    demo.queue().launch()