linoyts's picture
linoyts HF Staff
Update app.py (#1)
023d014 verified
"""
LTX-2 Gemma Text Encoder Space
Encodes text prompts using Gemma-3-12B for LTX-2 video generation.
Supports prompt enhancement for better results.
"""
import time
import traceback
import uuid
from pathlib import Path

import numpy as np
import spaces
import gradio as gr
import torch
from huggingface_hub import hf_hub_download, snapshot_download
# Upper bound for the enhancement-seed slider: the largest signed 32-bit integer.
MAX_SEED = int(np.iinfo(np.int32).max)
# Import from public LTX-2 package
# Install with: pip install git+https://github.com/Lightricks/LTX-2.git
from ltx_pipelines.utils import ModelLedger
from ltx_pipelines.utils.helpers import generate_enhanced_prompt
# Hugging Face Hub defaults.
# Repository that hosts the LTX-2 checkpoint file below.
DEFAULT_REPO_ID: str = "Lightricks/LTX-2"
# Gemma repository that is snapshot-downloaded for the text encoder.
DEFAULT_GEMMA_REPO_ID: str = "google/gemma-3-12b-it-qat-q4_0-unquantized"
# Checkpoint file fetched from DEFAULT_REPO_ID.
DEFAULT_CHECKPOINT_FILENAME: str = "ltx-2-19b-dev-fp8.safetensors"
def get_hub_or_local_checkpoint(repo_id: str, filename: str):
    """Fetch a single file from a Hugging Face Hub repository.

    Despite the name, no local-path lookup is performed here: the file is
    always resolved through ``hf_hub_download``.

    Args:
        repo_id: Hub repository id to download from.
        filename: Name of the file inside that repository.

    Returns:
        The local path returned by ``hf_hub_download``.
    """
    print(f"Downloading {filename} from {repo_id}...")
    local_file = hf_hub_download(filename=filename, repo_id=repo_id)
    print(f"Downloaded to {local_file}")
    return local_file
def download_gemma_model(repo_id: str):
    """Fetch a full snapshot of the Gemma model repository.

    Args:
        repo_id: Hub repository id of the Gemma model.

    Returns:
        The local directory returned by ``snapshot_download``.
    """
    print(f"Downloading Gemma model from {repo_id}...")
    snapshot_dir = snapshot_download(repo_id=repo_id)
    print(f"Gemma model downloaded to {snapshot_dir}")
    return snapshot_dir
# Initialize model ledger and text encoder at startup (load once, keep in memory).
# This runs at import time: it downloads the LTX-2 checkpoint and the Gemma
# snapshot, then builds the text encoder that encode_prompt() reuses.
print("=" * 80)
print("Loading Gemma Text Encoder...")
print("=" * 80)

checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
gemma_local_path = download_gemma_model(DEFAULT_GEMMA_REPO_ID)
device = "cuda"

print("Initializing text encoder with:")
print(f" checkpoint_path={checkpoint_path}")
print(f" gemma_root={gemma_local_path}")
print(f" device={device}")

model_ledger = ModelLedger(
    dtype=torch.bfloat16,
    device=device,
    checkpoint_path=checkpoint_path,
    # Use the snapshot directory downloaded above (the value logged as
    # gemma_root), not the bare repo id. Previously the download was unused and
    # the log line did not match what was actually passed.
    gemma_root_path=gemma_local_path,
    local_files_only=False,
)

# Load text encoder once and keep it in memory for all requests.
text_encoder = model_ledger.text_encoder()
print("=" * 80)
print("Text encoder loaded and ready!")
print("=" * 80)
def encode_text_simple(text_encoder, prompt: str):
    """Encode ``prompt`` and return only the video and audio contexts.

    ``text_encoder`` is called with the prompt and must return exactly three
    values; the third one is discarded.

    Returns:
        ``(video_context, audio_context)``.
    """
    video_ctx, audio_ctx, _discarded = text_encoder(prompt)
    return (video_ctx, audio_ctx)
def _resolve_image_path(input_image):
    """Return ``(path, created)`` for a reference image given to encode_prompt.

    Objects with a ``save`` method (e.g. PIL images) are written to a uniquely
    named PNG under ``temp_images/``; ``created`` is then True and the caller
    owns (and should delete) that file. Anything else is treated as an existing
    path and returned as ``str`` with ``created`` False. The UI's ``gr.Image``
    is configured with ``type="filepath"``, which takes this branch.
    """
    if not hasattr(input_image, "save"):
        return str(input_image), False
    temp_dir = Path("temp_images")
    temp_dir.mkdir(exist_ok=True)
    # A random name avoids two requests in the same second overwriting each
    # other's image (the old name was only int(time.time())). PNG instead of
    # JPEG because JPEG cannot store RGBA/palette images, so those saves failed.
    temp_image_path = temp_dir / f"temp_{uuid.uuid4().hex}.png"
    input_image.save(temp_image_path)
    return str(temp_image_path), True


def _format_status(elapsed_time: float) -> str:
    """Build the UI status line, including CUDA memory usage when CUDA is available."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        peak = torch.cuda.max_memory_allocated() / 1024**3
        return f"✓ Encoded in {elapsed_time:.2f}s | VRAM: {allocated:.2f}GB allocated, {peak:.2f}GB peak"
    return f"✓ Encoded in {elapsed_time:.2f}s (CPU mode)"


@spaces.GPU()
def encode_prompt(
    prompt: str,
    enhance_prompt: bool = True,
    input_image=None,
    seed: int = 42,
    negative_prompt: str = "",
):
    """
    Encode a text prompt (and optional negative prompt) with the Gemma text encoder.

    Args:
        prompt: Text prompt to encode.
        enhance_prompt: Whether to rewrite the prompt with
            ``generate_enhanced_prompt`` before encoding.
        input_image: Optional reference image for image-to-video enhancement.
            This is either a path or an object with a ``save`` method. It is
            used only when ``enhance_prompt`` is True.
        seed: Random seed for prompt enhancement.
        negative_prompt: Optional negative prompt for CFG (two-stage pipeline).
            It is encoded only when non-empty.

    Returns:
        tuple: ``(file_path, final_prompt, status_message)``.

        ``file_path`` is a ``.pt`` file under ``embeddings/`` written with
        ``torch.save``. It holds a dict with keys ``video_context``,
        ``audio_context``, ``prompt`` and ``original_prompt``. When a negative
        prompt was given, it also holds ``video_context_negative``,
        ``audio_context_negative`` and ``negative_prompt``. Tensors are moved
        to CPU before saving.

        On any exception nothing is raised; the function returns
        ``(None, prompt, error_message_with_traceback)`` instead.
    """
    start_time = time.time()
    # Set only when this call wrote a temp image that it must delete afterwards.
    owned_temp_image = None
    try:
        final_prompt = prompt
        if enhance_prompt:
            image_path = None
            if input_image is not None:
                image_path, created = _resolve_image_path(input_image)
                if created:
                    owned_temp_image = image_path
            final_prompt = generate_enhanced_prompt(
                text_encoder=text_encoder,
                prompt=prompt,
                image_path=image_path,
                seed=seed,
            )

        # Encode the positive prompt using the pre-loaded text encoder.
        video_context, audio_context = encode_text_simple(text_encoder, final_prompt)

        embedding_data = {
            'video_context': video_context.cpu(),
            'audio_context': audio_context.cpu(),
            'prompt': final_prompt,
            'original_prompt': prompt,
        }
        if negative_prompt:
            video_context_negative, audio_context_negative = encode_text_simple(
                text_encoder, negative_prompt
            )
            embedding_data['video_context_negative'] = video_context_negative.cpu()
            embedding_data['audio_context_negative'] = audio_context_negative.cpu()
            embedding_data['negative_prompt'] = negative_prompt

        output_dir = Path("embeddings")
        output_dir.mkdir(exist_ok=True)
        # The timestamp keeps files sortable. The random suffix prevents two
        # requests in the same second from overwriting (and then returning)
        # each other's embeddings.
        output_path = output_dir / f"embedding_{int(time.time())}_{uuid.uuid4().hex[:8]}.pt"
        torch.save(embedding_data, output_path)

        status = _format_status(time.time() - start_time)
        return str(output_path), final_prompt, status
    except Exception as e:
        # Top-level UI boundary: report the error in the status box instead of raising.
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, prompt, error_msg
    finally:
        if owned_temp_image is not None:
            try:
                Path(owned_temp_image).unlink()
            except OSError:
                pass  # best-effort cleanup; never mask the encode result
# Create Gradio interface.
# Left column: inputs. Right column: outputs.
# Component creation order inside each `with` context sets the on-screen order,
# so these statements must not be reordered.
with gr.Blocks(title="LTX-2 Gemma Text Encoder") as demo:
    gr.Markdown("# LTX-2 Gemma Text Encoder 🎯")
    gr.Markdown("""
Encode text prompts using Gemma-3-12B for LTX-2 video generation.
This space generates embeddings that can be used by the main LTX-2 generation space.
""")
    with gr.Row():
        # Input column.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5,
                value="An astronaut hatches from a fragile egg on the surface of the Moon"
            )
            # Empty by default; encode_prompt skips negative encoding when empty.
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (Optional)",
                placeholder="Enter negative prompt for CFG (used by two-stage pipeline)...",
                lines=2,
                value=""
            )
            enhance_checkbox = gr.Checkbox(
                label="Enhance Prompt",
                value=True,
                info="Use Gemma to automatically enhance your prompt for better results"
            )
            # Settings that only affect prompt enhancement.
            with gr.Accordion("Prompt Enhancement Settings", open=False):
                # type="filepath" makes Gradio hand the image to the callback as a file path.
                input_image = gr.Image(
                    label="Reference Image (Optional)",
                    type="filepath",
                )
                enhancement_seed = gr.Slider(
                    label="Enhancement Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    value=42,
                    step=1,
                    info="Random seed for prompt enhancement"
                )
            encode_btn = gr.Button("Encode Prompt", variant="primary", size="lg")
        # Output column: these map, in order, to encode_prompt's return tuple.
        with gr.Column():
            embedding_file = gr.File(label="Embedding File (.pt)")
            enhanced_prompt_output = gr.Textbox(
                label="Final Prompt Used",
                lines=5,
                info="This is the prompt that was encoded (enhanced if enabled)"
            )
            status_output = gr.Textbox(label="Status", lines=2)
    # Inputs are passed positionally to
    # encode_prompt(prompt, enhance_prompt, input_image, seed, negative_prompt).
    encode_btn.click(
        fn=encode_prompt,
        inputs=[prompt_input, enhance_checkbox, input_image, enhancement_seed, negative_prompt_input],
        outputs=[embedding_file, enhanced_prompt_output, status_output]
    )
# Center the main container and cap its width at 1200px.
css = (
    "\n"
    ".gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}\n"
)

if __name__ == "__main__":
    demo.launch(css=css)