# Source: Hugging Face Space by multimodalart (HF Staff) — app.py, commit 07b8a5b (verified).
"""
LTX-2.3 Gemma Text Encoder Space
Encodes text prompts using Gemma-3-12B for LTX-2.3 video generation.
Supports prompt enhancement for better results.
"""
import os
import subprocess
import sys
import time
from pathlib import Path
# ---------------------------------------------------------------------------
# Bootstrap: clone the LTX-2 repo and install its packages at import time,
# so the `ltx_pipelines` imports further down resolve.
# ---------------------------------------------------------------------------
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")

if not os.path.exists(LTX_REPO_DIR):
    print(f"Cloning {LTX_REPO_URL}...")
    # Shallow clone (--depth 1): only the latest commit is needed to install.
    subprocess.run(["git", "clone", "--depth", "1", LTX_REPO_URL, LTX_REPO_DIR], check=True)

# NOTE(review): original indentation was lost in this file; the install step
# may originally have been guarded by the clone check above — confirm. The
# --force-reinstall flag suggests it was intended to run unconditionally.
print("Installing ltx-core and ltx-pipelines from cloned repo...")
# Editable installs with --no-deps so the Space's pre-installed torch stack
# is left untouched.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "-e",
     os.path.join(LTX_REPO_DIR, "packages", "ltx-core"),
     "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines")],
    check=True,
)

# Also put the package sources directly on sys.path (belt-and-suspenders in
# case the editable install is not picked up by this interpreter).
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
import numpy as np
import spaces
import gradio as gr
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from ltx_pipelines.utils import ModelLedger
from ltx_pipelines.utils.helpers import generate_enhanced_prompt, encode_prompts
# Upper bound for the enhancement-seed slider (largest 32-bit signed int).
MAX_SEED = np.iinfo(np.int32).max

# Model repos
LTX_MODEL_REPO = "diffusers-internal-dev/ltx-23"
LTX_CHECKPOINT_FILENAME = "ltx-2.3-22b-distilled.safetensors"
GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"

# Download models (cached by huggingface_hub across restarts when the cache
# directory persists).
print("=" * 80)
print("Downloading models...")
print("=" * 80)
checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename=LTX_CHECKPOINT_FILENAME)
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
print(f"Checkpoint: {checkpoint_path}")
print(f"Gemma root: {gemma_root}")

# Initialize model ledger — the shared handle that encode_prompts() uses to
# reach the Gemma text encoder and embeddings processor.
print("=" * 80)
print("Loading Gemma Text Encoder for LTX-2.3...")
print("=" * 80)
device = "cuda"
model_ledger = ModelLedger(
    dtype=torch.bfloat16,
    device=device,
    checkpoint_path=checkpoint_path,
    gemma_root_path=gemma_root,
)
print("=" * 80)
print("Text encoder loaded and ready!")
print("=" * 80)
@spaces.GPU()
def encode_prompt(
    prompt: str,
    enhance_prompt: bool = True,
    input_image=None,
    seed: int = 42,
    negative_prompt: str = "",
):
    """
    Encode a text prompt using the Gemma text encoder for LTX-2.3.

    Uses encode_prompts(), which handles text_encoder.encode() followed by the
    embeddings processor. When a negative prompt is supplied, both prompts are
    encoded in one call and the negative embeddings are saved alongside the
    positive ones (for classifier-free guidance downstream).

    Args:
        prompt: Text prompt to encode.
        enhance_prompt: If True, let Gemma enhance the prompt first
            (enhance_first_prompt — only the first prompt is enhanced).
        input_image: Optional reference image for enhancement; either a file
            path (gr.Image(type="filepath") yields a str) or an object with a
            .save() method (e.g. a PIL image).
        seed: Seed for the prompt-enhancement step.
        negative_prompt: Optional negative prompt; empty string disables it.

    Returns:
        tuple: (embedding_file_path_or_None, prompt_text, status_message).
        On failure the first element is None and the status carries the
        traceback.
    """
    start_time = time.time()
    try:
        # First prompt is the positive one; negative (if any) rides along so
        # both are encoded in a single encode_prompts() call.
        prompts = [prompt]
        if negative_prompt:
            prompts.append(negative_prompt)

        image_path = _resolve_image_path(input_image)

        # encode_prompts handles: text_encoder.encode() -> embeddings_processor
        results = encode_prompts(
            prompts,
            model_ledger,
            enhance_first_prompt=enhance_prompt,
            enhance_prompt_image=image_path,
            enhance_prompt_seed=seed,
        )

        ctx_p = results[0]
        video_context = ctx_p.video_encoding
        audio_context = ctx_p.audio_encoding

        # Save embeddings to a .pt file the generation Space can load.
        output_dir = Path("embeddings")
        output_dir.mkdir(exist_ok=True)
        # time_ns() avoids filename collisions: int(time.time()) previously
        # allowed two requests in the same second to overwrite each other.
        output_path = output_dir / f"embedding_{time.time_ns()}.pt"

        embedding_data = {
            "video_context": video_context.cpu(),
            "audio_context": audio_context.cpu() if audio_context is not None else None,
            "prompt": prompt,
        }
        if negative_prompt and len(results) > 1:
            ctx_n = results[1]
            embedding_data["video_context_negative"] = ctx_n.video_encoding.cpu()
            embedding_data["audio_context_negative"] = ctx_n.audio_encoding.cpu() if ctx_n.audio_encoding is not None else None
            embedding_data["negative_prompt"] = negative_prompt

        torch.save(embedding_data, output_path)

        elapsed_time = time.time() - start_time
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            peak = torch.cuda.max_memory_allocated() / 1024**3
            status = f"Encoded in {elapsed_time:.2f}s | VRAM: {allocated:.2f}GB allocated, {peak:.2f}GB peak"
        else:
            status = f"Encoded in {elapsed_time:.2f}s (CPU mode)"

        return str(output_path), prompt, status
    except Exception as e:
        # Broad catch is deliberate at this Gradio-handler boundary: surface
        # the failure in the UI instead of crashing the worker.
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, prompt, error_msg


def _resolve_image_path(input_image):
    """Return a filesystem path (str) for the reference image, or None.

    gr.Image(type="filepath") hands us a plain path, which is returned as-is;
    objects exposing .save() (e.g. PIL images) are persisted to a temp file
    first. time_ns() keeps temp filenames unique across rapid requests.
    """
    if input_image is None:
        return None
    if hasattr(input_image, "save"):
        temp_dir = Path("temp_images")
        temp_dir.mkdir(exist_ok=True)
        temp_image_path = temp_dir / f"temp_{time.time_ns()}.jpg"
        input_image.save(temp_image_path)
        return str(temp_image_path)
    return str(input_image)
# ---------------------------------------------------------------------------
# Gradio UI: prompt inputs on the left, embedding file + status on the right.
# ---------------------------------------------------------------------------
with gr.Blocks(title="LTX-2.3 Gemma Text Encoder") as demo:
    gr.Markdown("# LTX-2.3 Gemma Text Encoder")
    gr.Markdown(
        "Encode text prompts using Gemma-3-12B for LTX-2.3 video generation. "
        "This space generates embeddings used by the LTX-2.3 generation space."
    )

    with gr.Row():
        with gr.Column():
            # --- Inputs ---
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5,
                value="An astronaut hatches from a fragile egg on the surface of the Moon",
            )
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (Optional)",
                placeholder="Enter negative prompt for CFG...",
                lines=2,
                value="",
            )
            enhance_checkbox = gr.Checkbox(
                label="Enhance Prompt",
                value=True,
                info="Use Gemma to automatically enhance your prompt",
            )
            with gr.Accordion("Prompt Enhancement Settings", open=False):
                # type="filepath" means encode_prompt receives a str path.
                input_image = gr.Image(label="Reference Image (Optional)", type="filepath")
                enhancement_seed = gr.Slider(
                    label="Enhancement Seed", minimum=0, maximum=MAX_SEED, value=42, step=1
                )
            encode_btn = gr.Button("Encode Prompt", variant="primary", size="lg")

        with gr.Column():
            # --- Outputs ---
            embedding_file = gr.File(label="Embedding File (.pt)")
            enhanced_prompt_output = gr.Textbox(label="Final Prompt Used", lines=5)
            status_output = gr.Textbox(label="Status", lines=2)

    # Input order must match encode_prompt's parameter order:
    # (prompt, enhance_prompt, input_image, seed, negative_prompt).
    encode_btn.click(
        fn=encode_prompt,
        inputs=[prompt_input, enhance_checkbox, input_image, enhancement_seed, negative_prompt_input],
        outputs=[embedding_file, enhanced_prompt_output, status_output],
    )

if __name__ == "__main__":
    demo.launch()