alexnasa commited on
Commit
1fec881
·
verified ·
1 Parent(s): 5b43fde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -47
app.py CHANGED
@@ -7,15 +7,15 @@ sys.path.insert(0, str(current_dir / "packages" / "ltx-pipelines" / "src"))
7
  sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src"))
8
 
9
  import spaces
 
 
10
  import gradio as gr
11
- from gradio_client import Client, handle_file
12
  import numpy as np
13
  import random
14
  import torch
15
  from typing import Optional
16
  from pathlib import Path
17
- from huggingface_hub import hf_hub_download
18
- from gradio_client import Client
19
  from ltx_pipelines.distilled import DistilledPipeline
20
  from ltx_core.model.video_vae import TilingConfig
21
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
@@ -29,7 +29,165 @@ from ltx_pipelines.utils.constants import (
29
  DEFAULT_LORA_STRENGTH,
30
  )
31
 
 
32
  MAX_SEED = np.iinfo(np.int32).max
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # Default prompt from docstring example
34
  DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."
35
 
@@ -86,6 +244,7 @@ loras = [
86
  # Initialize pipeline WITHOUT text encoder (gemma_root=None)
87
  # Text encoding will be done by external space
88
  pipeline = DistilledPipeline(
 
89
  checkpoint_path=checkpoint_path,
90
  spatial_upsampler_path=spatial_upsampler_path,
91
  gemma_root=None, # No text encoder in this space
@@ -93,23 +252,18 @@ pipeline = DistilledPipeline(
93
  fp8transformer=False,
94
  local_files_only=False,
95
  )
 
96
  pipeline._video_encoder = pipeline.model_ledger.video_encoder()
97
  pipeline._transformer = pipeline.model_ledger.transformer()
 
 
98
 
99
- # Initialize text encoder client
100
- print(f"Connecting to text encoder space: {TEXT_ENCODER_SPACE}")
101
- try:
102
- text_encoder_client = Client(TEXT_ENCODER_SPACE)
103
- print("✓ Text encoder client connected!")
104
- except Exception as e:
105
- print(f"⚠ Warning: Could not connect to text encoder space: {e}")
106
- text_encoder_client = None
107
 
108
  print("=" * 80)
109
  print("Pipeline fully loaded and ready!")
110
  print("=" * 80)
111
 
112
- @spaces.GPU(duration=300)
113
  def generate_video(
114
  input_image,
115
  prompt: str,
@@ -118,9 +272,10 @@ def generate_video(
118
  seed: int = 42,
119
  randomize_seed: bool = True,
120
  height: int = DEFAULT_1_STAGE_HEIGHT,
121
- width: int = DEFAULT_1_STAGE_WIDTH,
122
  progress=gr.Progress(track_tqdm=True)
123
  ):
 
124
  """Generate a video based on the given parameters."""
125
  try:
126
  # Randomize seed if checkbox is enabled
@@ -153,39 +308,25 @@ def generate_video(
153
  # Get embeddings from text encoder space
154
  print(f"Encoding prompt: {prompt}")
155
 
156
- if text_encoder_client is None:
157
- raise RuntimeError(
158
- f"Text encoder client not connected. Please ensure the text encoder space "
159
- f"({TEXT_ENCODER_SPACE}) is running and accessible."
160
- )
161
-
162
- try:
163
- # Prepare image for upload if it exists
164
- image_input = None
165
- if temp_image_path is not None:
166
- image_input = handle_file(str(temp_image_path))
167
-
168
- result = text_encoder_client.predict(
169
- prompt=prompt,
170
- enhance_prompt=enhance_prompt,
171
- input_image=image_input,
172
- seed=current_seed,
173
- negative_prompt="",
174
- api_name="/encode_prompt"
175
- )
176
- embedding_path = result[0] # Path to .pt file
177
- print(f"Embeddings received from: {embedding_path}")
178
-
179
- # Load embeddings
180
- embeddings = torch.load(embedding_path)
181
- video_context = embeddings['video_context']
182
- audio_context = embeddings['audio_context']
183
- print("✓ Embeddings loaded successfully")
184
- except Exception as e:
185
- raise RuntimeError(
186
- f"Failed to get embeddings from text encoder space: {e}\n"
187
- f"Please ensure {TEXT_ENCODER_SPACE} is running properly."
188
- )
189
 
190
  # Run inference - progress automatically tracks tqdm from pipeline
191
  pipeline(
@@ -321,4 +462,4 @@ css = '''
321
  .gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}
322
  '''
323
  if __name__ == "__main__":
324
- demo.launch(theme=gr.themes.Citrus(), css=css)
 
7
  sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src"))
8
 
9
  import spaces
10
+ import flash_attn_interface
11
+ import time
12
  import gradio as gr
 
13
  import numpy as np
14
  import random
15
  import torch
16
  from typing import Optional
17
  from pathlib import Path
18
+ from huggingface_hub import hf_hub_download, snapshot_download
 
19
  from ltx_pipelines.distilled import DistilledPipeline
20
  from ltx_core.model.video_vae import TilingConfig
21
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
 
29
  DEFAULT_LORA_STRENGTH,
30
  )
31
 
32
+
33
  MAX_SEED = np.iinfo(np.int32).max
34
+ # Import from public LTX-2 package
35
+ # Install with: pip install git+https://github.com/Lightricks/LTX-2.git
36
+ from ltx_pipelines.utils import ModelLedger
37
+ from ltx_pipelines.utils.helpers import generate_enhanced_prompt
38
+
39
+ # HuggingFace Hub defaults
40
+ DEFAULT_REPO_ID = "Lightricks/LTX-2"
41
+ DEFAULT_GEMMA_REPO_ID = "unsloth/gemma-3-12b-it-qat-bnb-4bit"
42
+ DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev.safetensors"
43
+
44
+
45
+ def get_hub_or_local_checkpoint(repo_id: str, filename: str):
46
+ """Download from HuggingFace Hub."""
47
+ print(f"Downloading {filename} from {repo_id}...")
48
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
49
+ print(f"Downloaded to {ckpt_path}")
50
+ return ckpt_path
51
+
52
+ def download_gemma_model(repo_id: str):
53
+ """Download the full Gemma model directory."""
54
+ print(f"Downloading Gemma model from {repo_id}...")
55
+ local_dir = snapshot_download(repo_id=repo_id)
56
+ print(f"Gemma model downloaded to {local_dir}")
57
+ return local_dir
58
+
59
+ # Initialize model ledger and text encoder at startup (load once, keep in memory)
60
+ print("=" * 80)
61
+ print("Loading Gemma Text Encoder...")
62
+ print("=" * 80)
63
+
64
+ checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
65
+ gemma_local_path = download_gemma_model(DEFAULT_GEMMA_REPO_ID)
66
+ device = "cuda"
67
+
68
+ print(f"Initializing text encoder with:")
69
+ print(f" checkpoint_path={checkpoint_path}")
70
+ print(f" gemma_root={gemma_local_path}")
71
+ print(f" device={device}")
72
+
73
+
74
+ model_ledger = ModelLedger(
75
+ dtype=torch.bfloat16,
76
+ device=device,
77
+ checkpoint_path=checkpoint_path,
78
+ gemma_root_path=DEFAULT_GEMMA_REPO_ID,
79
+ local_files_only=False
80
+ )
81
+
82
+
83
+ # Load text encoder once and keep it in memory
84
+ text_encoder = model_ledger.text_encoder()
85
+
86
+ print("=" * 80)
87
+ print("Text encoder loaded and ready!")
88
+ print("=" * 80)
89
+
90
+ def encode_text_simple(text_encoder, prompt: str):
91
+ """Simple text encoding without using pipeline_utils."""
92
+ v_context, a_context, _ = text_encoder(prompt)
93
+ return v_context, a_context
94
+
95
+ @spaces.GPU()
96
+ def encode_prompt(
97
+ prompt: str,
98
+ enhance_prompt: bool = True,
99
+ input_image = None,
100
+ seed: int = 42,
101
+ negative_prompt: str = ""
102
+ ):
103
+ """
104
+ Encode a text prompt using Gemma text encoder.
105
+ Args:
106
+ prompt: Text prompt to encode
107
+ enhance_prompt: Whether to use AI to enhance the prompt
108
+ input_image: Optional image for image-to-video enhancement
109
+ seed: Random seed for prompt enhancement
110
+ negative_prompt: Optional negative prompt for CFG (two-stage pipeline)
111
+ Returns:
112
+ tuple: (file_path, enhanced_prompt_text, status_message)
113
+ """
114
+ start_time = time.time()
115
+
116
+ try:
117
+ # Enhance prompt if requested
118
+ final_prompt = prompt
119
+ if enhance_prompt:
120
+ if input_image is not None:
121
+ # Save image temporarily
122
+ temp_dir = Path("temp_images")
123
+ temp_dir.mkdir(exist_ok=True)
124
+ temp_image_path = temp_dir / f"temp_{int(time.time())}.jpg"
125
+ if hasattr(input_image, 'save'):
126
+ input_image.save(temp_image_path)
127
+ else:
128
+ temp_image_path = input_image
129
+
130
+ final_prompt = generate_enhanced_prompt(
131
+ text_encoder=text_encoder,
132
+ prompt=prompt,
133
+ image_path=str(temp_image_path),
134
+ seed=seed
135
+ )
136
+ else:
137
+ final_prompt = generate_enhanced_prompt(
138
+ text_encoder=text_encoder,
139
+ prompt=prompt,
140
+ image_path=None,
141
+ seed=seed
142
+ )
143
+
144
+ # Encode the positive prompt using the pre-loaded text encoder
145
+ video_context, audio_context = encode_text_simple(text_encoder, final_prompt)
146
+
147
+ # Encode negative prompt if provided
148
+ video_context_negative = None
149
+ audio_context_negative = None
150
+ if negative_prompt:
151
+ video_context_negative, audio_context_negative = encode_text_simple(text_encoder, negative_prompt)
152
+
153
+ # Save embeddings to file
154
+ output_dir = Path("embeddings")
155
+ output_dir.mkdir(exist_ok=True)
156
+ output_path = output_dir / f"embedding_{int(time.time())}.pt"
157
+
158
+ # Save embeddings (with negative contexts if provided)
159
+ embedding_data = {
160
+ 'video_context': video_context.cpu(),
161
+ 'audio_context': audio_context.cpu(),
162
+ 'prompt': final_prompt,
163
+ 'original_prompt': prompt if enhance_prompt else final_prompt,
164
+ }
165
+
166
+ # Add negative contexts if they were encoded
167
+ if video_context_negative is not None:
168
+ embedding_data['video_context_negative'] = video_context_negative.cpu()
169
+ embedding_data['audio_context_negative'] = audio_context_negative.cpu()
170
+ embedding_data['negative_prompt'] = negative_prompt
171
+
172
+ torch.save(embedding_data, output_path)
173
+
174
+ # Get memory stats
175
+ elapsed_time = time.time() - start_time
176
+ if torch.cuda.is_available():
177
+ allocated = torch.cuda.memory_allocated() / 1024**3
178
+ peak = torch.cuda.max_memory_allocated() / 1024**3
179
+ status = f"✓ Encoded in {elapsed_time:.2f}s | VRAM: {allocated:.2f}GB allocated, {peak:.2f}GB peak"
180
+ else:
181
+ status = f"✓ Encoded in {elapsed_time:.2f}s (CPU mode)"
182
+
183
+ return str(output_path), final_prompt, status
184
+
185
+ except Exception as e:
186
+ import traceback
187
+ error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
188
+ print(error_msg)
189
+ return None, prompt, error_msg
190
+
191
  # Default prompt from docstring example
192
  DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."
193
 
 
244
  # Initialize pipeline WITHOUT text encoder (gemma_root=None)
245
  # Text encoding will be done by external space
246
  pipeline = DistilledPipeline(
247
+ device=torch.device("cuda"),
248
  checkpoint_path=checkpoint_path,
249
  spatial_upsampler_path=spatial_upsampler_path,
250
  gemma_root=None, # No text encoder in this space
 
252
  fp8transformer=False,
253
  local_files_only=False,
254
  )
255
+
256
  pipeline._video_encoder = pipeline.model_ledger.video_encoder()
257
  pipeline._transformer = pipeline.model_ledger.transformer()
258
+ # pipeline.device = torch.device("cuda")
259
+ # pipeline.model_ledger.device = torch.device("cuda")
260
 
 
 
 
 
 
 
 
 
261
 
262
  print("=" * 80)
263
  print("Pipeline fully loaded and ready!")
264
  print("=" * 80)
265
 
266
+ @spaces.GPU(duration=80)
267
  def generate_video(
268
  input_image,
269
  prompt: str,
 
272
  seed: int = 42,
273
  randomize_seed: bool = True,
274
  height: int = DEFAULT_1_STAGE_HEIGHT,
275
+ width: int = DEFAULT_1_STAGE_WIDTH ,
276
  progress=gr.Progress(track_tqdm=True)
277
  ):
278
+
279
  """Generate a video based on the given parameters."""
280
  try:
281
  # Randomize seed if checkbox is enabled
 
308
  # Get embeddings from text encoder space
309
  print(f"Encoding prompt: {prompt}")
310
 
311
+ # Prepare image for upload if it exists
312
+ image_input = None
313
+
314
+
315
+ result = encode_prompt(
316
+ prompt=prompt,
317
+ enhance_prompt=enhance_prompt,
318
+ input_image=input_image,
319
+ seed=current_seed,
320
+ negative_prompt="",
321
+ )
322
+ embedding_path = result[0] # Path to .pt file
323
+ print(f"Embeddings received from: {embedding_path}")
324
+
325
+ # Load embeddings
326
+ embeddings = torch.load(embedding_path)
327
+ video_context = embeddings['video_context']
328
+ audio_context = embeddings['audio_context']
329
+ print("✓ Embeddings loaded successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
  # Run inference - progress automatically tracks tqdm from pipeline
332
  pipeline(
 
462
  .gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}
463
  '''
464
  if __name__ == "__main__":
465
+ demo.launch(theme=gr.themes.Citrus(), css=css)