# LTX2_distill / text.py
# Gradio Space by rahul7star — distilled LTX-2 text-to-video demo.
# (last update: commit dfb0ede, "Update text.py")
import spaces
import gradio as gr
import torch
import random
import numpy as np
from diffusers.pipelines.ltx2 import (
LTX2Pipeline,
LTX2LatentUpsamplePipeline
)
from diffusers.pipelines.ltx2.latent_upsampler import (
LTX2LatentUpsamplerModel
)
from diffusers.pipelines.ltx2.utils import (
DISTILLED_SIGMA_VALUES,
STAGE_2_DISTILLED_SIGMA_VALUES
)
from diffusers.pipelines.ltx2.export_utils import encode_video
# ============================================================
# 🔥 GLOBAL SETTINGS (H200 SAFE)
# ============================================================
# Allow TF32 matmul/conv kernels on Ampere+ GPUs — a small precision
# trade-off that is fine for generative inference.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Inference-only process: disable autograd globally.
torch.set_grad_enabled(False)
# Opt into fused / memory-efficient scaled-dot-product attention backends.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

DEVICE = "cuda"          # every pipeline below is placed on this device
DTYPE = torch.bfloat16   # bf16 weights/activations for memory + speed
MODEL_ID = "rootonchair/LTX-2-19b-distilled"

print("🚀 Loading LTX-2 Text-to-Video...")
# Main text-to-video pipeline; used for both stage-1 and stage-2 sampling.
pipe = LTX2Pipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
).to(DEVICE)
# Latent upsampler model, loaded from a subfolder of the same repo.
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    MODEL_ID,
    subfolder="latent_upsampler",
    torch_dtype=DTYPE,
).to(DEVICE)
# Upsampling pipeline reuses the main pipeline's VAE (shared weights,
# no second copy in GPU memory).
upsample_pipe = LTX2LatentUpsamplePipeline(
    vae=pipe.vae,
    latent_upsampler=latent_upsampler
).to(DEVICE)
# Fuse the detailer IC-LoRA into the base weights at scale 0.8, then drop
# the adapter bookkeeping so inference runs on the fused weights only.
# NOTE(review): adapter_name is "camera_control" but the repo is the
# Detailer LoRA — the name looks like a copy-paste leftover; confirm.
pipe.load_lora_weights(
    "Lightricks/LTX-2-19b-IC-LoRA-Detailer",
    adapter_name="camera_control"
)
pipe.fuse_lora(lora_scale=0.8)
pipe.unload_lora_weights()
print("🔥 Model fully loaded on CUDA.")
print("✅ Model Loaded")
# ============================================================
# 🎬 GENERATION
# ============================================================
@spaces.GPU(duration=85, size="xlarge")
def generate(prompt, negative_prompt, duration, seed):
    """Generate a short video (with audio) from a text prompt.

    Two-stage distilled LTX-2 sampling: stage 1 produces video/audio
    latents, the latent upsampler raises spatial resolution, and stage 2
    runs a short refinement pass before decoding to pixel frames.

    Args:
        prompt: Text description of the desired clip.
        negative_prompt: Concepts to steer the sampler away from.
        duration: Requested clip length in seconds (slider value).
        seed: Integer RNG seed; -1 selects a random seed.

    Returns:
        Tuple ``(output_path, seed)``: path of the encoded .mp4 and the
        seed actually used (so the UI can display it).
    """
    # Coerce up front: a float seed would break determinism expectations
    # and yield a filename like "t2v_42.0.mp4" below.
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 1_000_000)
    # Use the module-wide DEVICE constant so the generator always lives on
    # the same device as the pipelines (was hard-coded to "cuda").
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    width = 768
    height = 512
    fps = 24
    # LTX-2 wants num_frames ≡ 1 (mod 8); snap the requested duration to
    # the nearest valid frame count, with a floor of 9 frames.
    total_frames = int(duration * fps)
    num_frames = max((round(total_frames / 8) * 8) + 1, 9)

    with torch.inference_mode():
        # Stage 1: distilled sampling straight to latents (no decode yet).
        video_latent, audio_latent = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_frames=num_frames,
            frame_rate=fps,
            num_inference_steps=8,
            sigmas=DISTILLED_SIGMA_VALUES,
            guidance_scale=1.0,
            generator=generator,
            output_type="latent",
            return_dict=False,
        )
        # Latent upscale: raise the spatial resolution of the video latents.
        upscaled_video_latent = upsample_pipe(
            latents=video_latent,
            output_type="latent",
            return_dict=False,
        )[0]
        # Stage 2: 3-step refinement on the upscaled latents, decoded to
        # numpy frames (presumably floats in [0, 1] — see scaling below).
        video, audio = pipe(
            latents=upscaled_video_latent,
            audio_latents=audio_latent,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=3,
            noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
            sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
            generator=generator,
            guidance_scale=1.0,
            output_type="np",
            return_dict=False,
        )

    # Scale float frames to uint8 for the encoder, then hand the tensor
    # (first/only batch element) plus the audio track to the muxer.
    video = (video * 255).round().astype("uint8")
    video = torch.from_numpy(video)

    output_path = f"t2v_{seed}.mp4"
    encode_video(
        video[0],
        fps=fps,
        audio=audio[0].float().cpu(),
        audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
        output_path=output_path,
    )
    return output_path, seed
# ============================================================
# 🖥️ UI
# ============================================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 LTX-2 Distilled Text-to-Video")

    # Output player sits above the input controls.
    video_out = gr.Video()

    # Input controls, in display order.
    prompt_box = gr.Textbox(
        label="Prompt",
        lines=3,
        value="A cinematic sunset over the ocean",
    )
    neg_prompt_box = gr.Textbox(
        label="Negative Prompt",
        value="low quality, distorted, glitchy",
    )
    duration_slider = gr.Slider(1, 12, value=4, label="Duration (seconds)")
    seed_box = gr.Number(value=-1, precision=0, label="Seed")
    generate_btn = gr.Button("Generate", variant="primary")

    # Seed is wired as both input and output so a randomly drawn seed is
    # written back into the field after generation.
    generate_btn.click(
        fn=generate,
        inputs=[prompt_box, neg_prompt_box, duration_slider, seed_box],
        outputs=[video_out, seed_box],
    )

if __name__ == "__main__":
    demo.queue().launch()