"""
LTX-2 Gemma Text Encoder Space.

Encodes text prompts using Gemma-3-12B for LTX-2 video generation and saves the
resulting video/audio context tensors to a ``.pt`` file.
Supports optional prompt enhancement (text-only, or guided by a reference image).
"""

import time
import traceback
import uuid
from pathlib import Path

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download, snapshot_download

# Import from public LTX-2 package
# Install with: pip install git+https://github.com/Lightricks/LTX-2.git
from ltx_pipelines.utils import ModelLedger
from ltx_pipelines.utils.helpers import generate_enhanced_prompt

MAX_SEED = np.iinfo(np.int32).max

# HuggingFace Hub defaults
DEFAULT_REPO_ID = "Lightricks/LTX-2"
DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"

# Per-request artifact directories (relative to the current working directory).
TEMP_IMAGE_DIR = Path("temp_images")
EMBEDDINGS_DIR = Path("embeddings")


def get_hub_or_local_checkpoint(repo_id: str, filename: str):
    """Download ``filename`` from the HuggingFace Hub repo ``repo_id``.

    Despite the name, this always goes through ``hf_hub_download`` (which reuses
    the local HF cache when the file is already present).

    Returns:
        The local filesystem path of the downloaded file.
    """
    print(f"Downloading {filename} from {repo_id}...")
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f"Downloaded to {ckpt_path}")
    return ckpt_path


def download_gemma_model(repo_id: str):
    """Download the full Gemma model repository snapshot.

    Returns:
        The local directory containing the snapshot.
    """
    print(f"Downloading Gemma model from {repo_id}...")
    local_dir = snapshot_download(repo_id=repo_id)
    print(f"Gemma model downloaded to {local_dir}")
    return local_dir


def _unique_stem(prefix: str) -> str:
    """Return ``<prefix>_<unix seconds>_<random hex>``.

    The random suffix keeps concurrent requests that start in the same second
    from overwriting each other's files.
    """
    return f"{prefix}_{int(time.time())}_{uuid.uuid4().hex[:8]}"


# Initialize model ledger and text encoder at startup (load once, keep in memory)
print("=" * 80)
print("Loading Gemma Text Encoder...")
print("=" * 80)

checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
# Pre-fetches the Gemma weights into the local HF cache.
gemma_local_path = download_gemma_model(DEFAULT_GEMMA_REPO_ID)
device = "cuda"

print("Initializing text encoder with:")
print(f"  checkpoint_path={checkpoint_path}")
print(f"  gemma_root={DEFAULT_GEMMA_REPO_ID} (snapshot at {gemma_local_path})")
print(f"  device={device}")

# NOTE(review): the repo id (not gemma_local_path) is passed as gemma_root_path,
# together with local_files_only=False; confirm against ModelLedger whether it
# expects a Hub id or a local directory here.
model_ledger = ModelLedger(
    dtype=torch.bfloat16,
    device=device,
    checkpoint_path=checkpoint_path,
    gemma_root_path=DEFAULT_GEMMA_REPO_ID,
    local_files_only=False,
)

# Load text encoder once and keep it in memory
text_encoder = model_ledger.text_encoder()

print("=" * 80)
print("Text encoder loaded and ready!")
print("=" * 80)


def encode_text_simple(text_encoder, prompt: str):
    """Encode ``prompt`` with ``text_encoder`` and return (video_context, audio_context).

    The encoder returns three values; the third is discarded here.
    """
    v_context, a_context, _ = text_encoder(prompt)
    return v_context, a_context


def _prepare_image_path(input_image):
    """Resolve the reference image into a path usable by the prompt enhancer.

    ``input_image`` may be a filesystem path (the UI uses ``type="filepath"``) or
    an object with a ``save`` method such as a PIL image. The latter is written to a
    uniquely named JPEG in ``TEMP_IMAGE_DIR``.

    Returns:
        (image_path, temp_file): ``image_path`` is a str or None. ``temp_file`` is the
        Path created by this function, which the caller must delete, or None when
        nothing was created.
    """
    if input_image is None:
        return None, None
    if not hasattr(input_image, "save"):
        return str(input_image), None

    TEMP_IMAGE_DIR.mkdir(exist_ok=True)
    temp_path = TEMP_IMAGE_DIR / f"{_unique_stem('temp')}.jpg"
    # JPEG cannot store alpha/palette modes; normalize to RGB when possible.
    image = input_image.convert("RGB") if hasattr(input_image, "convert") else input_image
    image.save(temp_path)
    return str(temp_path), temp_path


def _format_status(elapsed_time: float) -> str:
    """Build the success status line, including CUDA memory stats when available."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        peak = torch.cuda.max_memory_allocated() / 1024**3
        return f"✓ Encoded in {elapsed_time:.2f}s | VRAM: {allocated:.2f}GB allocated, {peak:.2f}GB peak"
    return f"✓ Encoded in {elapsed_time:.2f}s (CPU mode)"


@spaces.GPU()
def encode_prompt(
    prompt: str,
    enhance_prompt: bool = True,
    input_image=None,
    seed: int = 42,
    negative_prompt: str = "",
):
    """
    Encode a text prompt using Gemma text encoder.

    Args:
        prompt: Text prompt to encode
        enhance_prompt: Whether to use AI to enhance the prompt
        input_image: Optional image (path or PIL-like object) used to guide enhancement.
            It is ignored when ``enhance_prompt`` is False.
        seed: Random seed for prompt enhancement
        negative_prompt: Optional negative prompt for CFG (two-stage pipeline).
            Empty or whitespace-only values are treated as absent.

    Returns:
        tuple: (file_path, enhanced_prompt_text, status_message). On failure, this is
        (None, prompt, error message with traceback).
    """
    start_time = time.time()
    temp_image_file = None
    try:
        if torch.cuda.is_available():
            # Make the reported "peak" reflect this request, not the process lifetime.
            torch.cuda.reset_peak_memory_stats()

        # Enhance prompt if requested
        final_prompt = prompt
        if enhance_prompt:
            image_path, temp_image_file = _prepare_image_path(input_image)
            final_prompt = generate_enhanced_prompt(
                text_encoder=text_encoder,
                prompt=prompt,
                image_path=image_path,
                seed=int(seed),
            )

        # Encode the positive prompt using the pre-loaded text encoder
        video_context, audio_context = encode_text_simple(text_encoder, final_prompt)

        embedding_data = {
            "video_context": video_context.cpu(),
            "audio_context": audio_context.cpu(),
            "prompt": final_prompt,
            "original_prompt": prompt,
        }

        # Encode negative prompt if provided (with negative contexts stored alongside)
        if negative_prompt and negative_prompt.strip():
            video_negative, audio_negative = encode_text_simple(text_encoder, negative_prompt)
            embedding_data["video_context_negative"] = video_negative.cpu()
            embedding_data["audio_context_negative"] = audio_negative.cpu()
            embedding_data["negative_prompt"] = negative_prompt

        # Save embeddings to a uniquely named file
        EMBEDDINGS_DIR.mkdir(exist_ok=True)
        output_path = EMBEDDINGS_DIR / f"{_unique_stem('embedding')}.pt"
        torch.save(embedding_data, output_path)

        status = _format_status(time.time() - start_time)
        return str(output_path), final_prompt, status

    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing the app.
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, prompt, error_msg

    finally:
        # Only delete the temp image this request created, never a caller-provided path.
        if temp_image_file is not None:
            temp_image_file.unlink(missing_ok=True)


# Create Gradio interface
with gr.Blocks(title="LTX-2 Gemma Text Encoder") as demo:
    gr.Markdown("# LTX-2 Gemma Text Encoder 🎯")
    gr.Markdown("""
    Encode text prompts using Gemma-3-12B for LTX-2 video generation.
    This space generates embeddings that can be used by the main LTX-2 generation space.
    """)

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5,
                value="An astronaut hatches from a fragile egg on the surface of the Moon",
            )

            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (Optional)",
                placeholder="Enter negative prompt for CFG (used by two-stage pipeline)...",
                lines=2,
                value="",
            )

            enhance_checkbox = gr.Checkbox(
                label="Enhance Prompt",
                value=True,
                info="Use Gemma to automatically enhance your prompt for better results",
            )

            with gr.Accordion("Prompt Enhancement Settings", open=False):
                input_image = gr.Image(
                    label="Reference Image (Optional)",
                    type="filepath",
                )

                enhancement_seed = gr.Slider(
                    label="Enhancement Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    value=42,
                    step=1,
                    info="Random seed for prompt enhancement",
                )

            encode_btn = gr.Button("Encode Prompt", variant="primary", size="lg")

        with gr.Column():
            embedding_file = gr.File(label="Embedding File (.pt)")
            enhanced_prompt_output = gr.Textbox(
                label="Final Prompt Used",
                lines=5,
                info="This is the prompt that was encoded (enhanced if enabled)",
            )
            status_output = gr.Textbox(label="Status", lines=2)

    encode_btn.click(
        fn=encode_prompt,
        inputs=[prompt_input, enhance_checkbox, input_image, enhancement_seed, negative_prompt_input],
        outputs=[embedding_file, enhanced_prompt_output, status_output],
    )

css = '''
.gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}
'''

if __name__ == "__main__":
    demo.launch(css=css)