linoyts HF Staff commited on
Commit
5625a61
·
verified ·
1 Parent(s): 2576965

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -0
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LTX-2 Gemma Text Encoder Space
3
+ Encodes text prompts using Gemma-3-12B for LTX-2 video generation.
4
+ Supports prompt enhancement for better results.
5
+ """
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import spaces
10
+ import gradio as gr
11
+ import torch
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ # Import from public LTX-2 package
15
+ # Install with: pip install git+https://github.com/Lightricks/LTX-2.git
16
+ from ltx_pipelines.utils import ModelLedger, get_device
17
+ from ltx_pipelines.utils.helpers import generate_enhanced_prompt
18
+
19
+ # HuggingFace Hub defaults
20
+ DEFAULT_REPO_ID = "Lightricks/LTX-2"
21
+ DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
22
+ DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
23
+
24
+ def get_hub_or_local_checkpoint(repo_id: str, filename: str):
25
+ """Download from HuggingFace Hub."""
26
+ print(f"Downloading {filename} from {repo_id}...")
27
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
28
+ print(f"Downloaded to {ckpt_path}")
29
+ return ckpt_path
30
+
31
+ # Initialize model ledger and text encoder at startup (load once, keep in memory)
32
+ print("=" * 80)
33
+ print("Loading Gemma Text Encoder...")
34
+ print("=" * 80)
35
+
36
+ checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
37
+ device = get_device()
38
+
39
+ print(f"Initializing text encoder with:")
40
+ print(f" checkpoint_path={checkpoint_path}")
41
+ print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}")
42
+ print(f" device={device}")
43
+
44
+ model_ledger = ModelLedger(
45
+ dtype=torch.bfloat16,
46
+ device=device,
47
+ checkpoint_path=checkpoint_path,
48
+ gemma_root_path=DEFAULT_GEMMA_REPO_ID,
49
+ local_files_only=False,
50
+ )
51
+
52
+ # Load text encoder once and keep it in memory
53
+ text_encoder = model_ledger.text_encoder()
54
+
55
+ print("=" * 80)
56
+ print("Text encoder loaded and ready!")
57
+ print("=" * 80)
58
+
59
+ def encode_text_simple(text_encoder, prompt: str):
60
+ """Simple text encoding without using pipeline_utils."""
61
+ v_context, a_context, _ = text_encoder(prompt)
62
+ return v_context, a_context
63
+
64
+ @spaces.GPU()
65
+ def encode_prompt(
66
+ prompt: str,
67
+ enhance_prompt: bool = False,
68
+ input_image = None,
69
+ seed: int = 42
70
+ ):
71
+ """
72
+ Encode a text prompt using Gemma text encoder.
73
+
74
+ Args:
75
+ prompt: Text prompt to encode
76
+ enhance_prompt: Whether to use AI to enhance the prompt
77
+ input_image: Optional image for image-to-video enhancement
78
+ seed: Random seed for prompt enhancement
79
+
80
+ Returns:
81
+ tuple: (file_path, enhanced_prompt_text, status_message)
82
+ """
83
+ start_time = time.time()
84
+
85
+ try:
86
+ # Enhance prompt if requested
87
+ final_prompt = prompt
88
+ if enhance_prompt:
89
+ if input_image is not None:
90
+ # Save image temporarily
91
+ temp_dir = Path("temp_images")
92
+ temp_dir.mkdir(exist_ok=True)
93
+ temp_image_path = temp_dir / f"temp_{int(time.time())}.jpg"
94
+ if hasattr(input_image, 'save'):
95
+ input_image.save(temp_image_path)
96
+ else:
97
+ temp_image_path = input_image
98
+
99
+ final_prompt = generate_enhanced_prompt(
100
+ text_encoder=text_encoder,
101
+ prompt=prompt,
102
+ image_path=str(temp_image_path),
103
+ seed=seed
104
+ )
105
+ else:
106
+ final_prompt = generate_enhanced_prompt(
107
+ text_encoder=text_encoder,
108
+ prompt=prompt,
109
+ image_path=None,
110
+ seed=seed
111
+ )
112
+
113
+ # Encode the prompt using the pre-loaded text encoder
114
+ video_context, audio_context = encode_text_simple(text_encoder, final_prompt)
115
+
116
+ # Save embeddings to file
117
+ output_dir = Path("embeddings")
118
+ output_dir.mkdir(exist_ok=True)
119
+ output_path = output_dir / f"embedding_{int(time.time())}.pt"
120
+
121
+ # Save both video and audio contexts
122
+ torch.save({
123
+ 'video_context': video_context.cpu(),
124
+ 'audio_context': audio_context.cpu(),
125
+ 'prompt': final_prompt,
126
+ 'original_prompt': prompt if enhance_prompt else final_prompt,
127
+ }, output_path)
128
+
129
+ # Get memory stats
130
+ elapsed_time = time.time() - start_time
131
+ if torch.cuda.is_available():
132
+ allocated = torch.cuda.memory_allocated() / 1024**3
133
+ peak = torch.cuda.max_memory_allocated() / 1024**3
134
+ status = f"✓ Encoded in {elapsed_time:.2f}s | VRAM: {allocated:.2f}GB allocated, {peak:.2f}GB peak"
135
+ else:
136
+ status = f"✓ Encoded in {elapsed_time:.2f}s (CPU mode)"
137
+
138
+ return str(output_path), final_prompt, status
139
+
140
+ except Exception as e:
141
+ import traceback
142
+ error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
143
+ print(error_msg)
144
+ return None, prompt, error_msg
145
+
146
+
147
+ # Create Gradio interface
148
+ with gr.Blocks(title="LTX-2 Gemma Text Encoder") as demo:
149
+ gr.Markdown("# LTX-2 Gemma Text Encoder 🎯")
150
+ gr.Markdown("""
151
+ Encode text prompts using Gemma-3-12B for LTX-2 video generation.
152
+ This space generates embeddings that can be used by the main LTX-2 generation space.
153
+ """)
154
+
155
+ with gr.Row():
156
+ with gr.Column():
157
+ prompt_input = gr.Textbox(
158
+ label="Prompt",
159
+ placeholder="Enter your prompt here...",
160
+ lines=5,
161
+ value="An astronaut hatches from a fragile egg on the surface of the Moon"
162
+ )
163
+
164
+ enhance_checkbox = gr.Checkbox(
165
+ label="Enhance Prompt with AI",
166
+ value=False,
167
+ info="Use Gemma to automatically enhance your prompt for better results"
168
+ )
169
+
170
+ with gr.Accordion("Prompt Enhancement Settings", open=False):
171
+ input_image = gr.Image(
172
+ label="Reference Image (Optional)",
173
+ type="pil",
174
+ info="Provide an image for image-to-video prompt enhancement"
175
+ )
176
+ enhancement_seed = gr.Slider(
177
+ label="Enhancement Seed",
178
+ minimum=0,
179
+ maximum=10000,
180
+ value=42,
181
+ step=1,
182
+ info="Random seed for prompt enhancement"
183
+ )
184
+
185
+ encode_btn = gr.Button("Encode Prompt", variant="primary", size="lg")
186
+
187
+ with gr.Column():
188
+ embedding_file = gr.File(label="Embedding File (.pt)")
189
+ enhanced_prompt_output = gr.Textbox(
190
+ label="Final Prompt Used",
191
+ lines=5,
192
+ info="This is the prompt that was encoded (enhanced if enabled)"
193
+ )
194
+ status_output = gr.Textbox(label="Status", lines=2)
195
+
196
+ encode_btn.click(
197
+ fn=encode_prompt,
198
+ inputs=[prompt_input, enhance_checkbox, input_image, enhancement_seed],
199
+ outputs=[embedding_file, enhanced_prompt_output, status_output]
200
+ )
201
+
202
+ css = '''
203
+ .gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}
204
+ '''
205
+
206
+ if __name__ == "__main__":
207
+ demo.launch(css=css)