linoyts HF Staff commited on
Commit
ade6776
·
verified ·
1 Parent(s): afd2ca0

improvements + fixes [wip]

Browse files
Files changed (1) hide show
  1. app.py +64 -7
app.py CHANGED
@@ -8,10 +8,14 @@ sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src"))
8
 
9
  import spaces
10
  import gradio as gr
 
11
  import numpy as np
12
  import random
 
13
  from typing import Optional
 
14
  from huggingface_hub import hf_hub_download
 
15
  from ltx_pipelines.distilled import DistilledPipeline
16
  from ltx_core.tiling import TilingConfig
17
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
@@ -31,11 +35,13 @@ DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the
31
 
32
  # HuggingFace Hub defaults
33
  DEFAULT_REPO_ID = "Lightricks/LTX-2"
34
- DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
35
  DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
36
  DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384.safetensors"
37
  DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0.safetensors"
38
 
 
 
 
39
  def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
40
  """Download from HuggingFace Hub or use local checkpoint."""
41
  if repo_id is None and filename is None:
@@ -66,7 +72,7 @@ print(f"Initializing pipeline with:")
66
  print(f" checkpoint_path={checkpoint_path}")
67
  print(f" distilled_lora_path={distilled_lora_path}")
68
  print(f" spatial_upsampler_path={spatial_upsampler_path}")
69
- print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}")
70
 
71
  # Load distilled LoRA as a regular LoRA
72
  loras = [
@@ -77,15 +83,26 @@ loras = [
77
  )
78
  ]
79
 
 
 
80
  pipeline = DistilledPipeline(
81
  checkpoint_path=checkpoint_path,
82
  spatial_upsampler_path=spatial_upsampler_path,
83
- gemma_root=DEFAULT_GEMMA_REPO_ID,
84
  loras=loras,
85
  fp8transformer=True,
86
  local_files_only=False,
87
  )
88
 
 
 
 
 
 
 
 
 
 
89
  print("=" * 80)
90
  print("Pipeline fully loaded and ready!")
91
  print("=" * 80)
@@ -113,20 +130,58 @@ def generate_video(
113
  # Create output directory if it doesn't exist
114
  output_dir = Path("outputs")
115
  output_dir.mkdir(exist_ok=True)
116
- output_path = output_dir / f"video_{seed}.mp4"
117
 
118
  # Handle image input
119
  images = []
 
 
120
  if input_image is not None:
121
  # Save uploaded image temporarily
122
- temp_image_path = output_dir / f"temp_input_{seed}.jpg"
123
  if hasattr(input_image, 'save'):
124
  input_image.save(temp_image_path)
125
  else:
126
  # If it's a file path already
127
- temp_image_path = input_image
128
  # Format: (image_path, frame_idx, strength)
129
  images = [(str(temp_image_path), 0, 1.0)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  # Run inference - progress automatically tracks tqdm from pipeline
132
  pipeline(
@@ -139,6 +194,8 @@ def generate_video(
139
  frame_rate=frame_rate,
140
  images=images,
141
  tiling_config=TilingConfig.default(),
 
 
142
  )
143
 
144
  return str(output_path), current_seed
@@ -250,4 +307,4 @@ css = '''
250
  .gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}
251
  '''
252
  if __name__ == "__main__":
253
- demo.launch(theme=gr.themes.Citrus(), css=css)
 
8
 
9
  import spaces
10
  import gradio as gr
11
+ from gradio_client import Client, handle_file
12
  import numpy as np
13
  import random
14
+ import torch
15
  from typing import Optional
16
+ from pathlib import Path
17
  from huggingface_hub import hf_hub_download
18
+ from gradio_client import Client
19
  from ltx_pipelines.distilled import DistilledPipeline
20
  from ltx_core.tiling import TilingConfig
21
  from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
 
35
 
36
  # HuggingFace Hub defaults
37
  DEFAULT_REPO_ID = "Lightricks/LTX-2"
 
38
  DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
39
  DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384.safetensors"
40
  DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0.safetensors"
41
 
42
+ # Text encoder space URL
43
+ TEXT_ENCODER_SPACE = "linoyts/gemma-text-encoder"
44
+
45
  def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
46
  """Download from HuggingFace Hub or use local checkpoint."""
47
  if repo_id is None and filename is None:
 
72
  print(f" checkpoint_path={checkpoint_path}")
73
  print(f" distilled_lora_path={distilled_lora_path}")
74
  print(f" spatial_upsampler_path={spatial_upsampler_path}")
75
+ print(f" text_encoder_space={TEXT_ENCODER_SPACE}")
76
 
77
  # Load distilled LoRA as a regular LoRA
78
  loras = [
 
83
  )
84
  ]
85
 
86
+ # Initialize pipeline WITHOUT text encoder (gemma_root=None)
87
+ # Text encoding will be done by external space
88
  pipeline = DistilledPipeline(
89
  checkpoint_path=checkpoint_path,
90
  spatial_upsampler_path=spatial_upsampler_path,
91
+ gemma_root=None, # No text encoder in this space
92
  loras=loras,
93
  fp8transformer=True,
94
  local_files_only=False,
95
  )
96
 
97
+ # Initialize text encoder client
98
+ print(f"Connecting to text encoder space: {TEXT_ENCODER_SPACE}")
99
+ try:
100
+ text_encoder_client = Client(TEXT_ENCODER_SPACE)
101
+ print("✓ Text encoder client connected!")
102
+ except Exception as e:
103
+ print(f"⚠ Warning: Could not connect to text encoder space: {e}")
104
+ text_encoder_client = None
105
+
106
  print("=" * 80)
107
  print("Pipeline fully loaded and ready!")
108
  print("=" * 80)
 
130
  # Create output directory if it doesn't exist
131
  output_dir = Path("outputs")
132
  output_dir.mkdir(exist_ok=True)
133
+ output_path = output_dir / f"video_{current_seed}.mp4"
134
 
135
  # Handle image input
136
  images = []
137
+ temp_image_path = None # Initialize to None
138
+
139
  if input_image is not None:
140
  # Save uploaded image temporarily
141
+ temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
142
  if hasattr(input_image, 'save'):
143
  input_image.save(temp_image_path)
144
  else:
145
  # If it's a file path already
146
+ temp_image_path = Path(input_image)
147
  # Format: (image_path, frame_idx, strength)
148
  images = [(str(temp_image_path), 0, 1.0)]
149
+
150
+ # Get embeddings from text encoder space
151
+ print(f"Encoding prompt: {prompt}")
152
+
153
+ if text_encoder_client is None:
154
+ raise RuntimeError(
155
+ f"Text encoder client not connected. Please ensure the text encoder space "
156
+ f"({TEXT_ENCODER_SPACE}) is running and accessible."
157
+ )
158
+
159
+ try:
160
+ # Prepare image for upload if it exists
161
+ image_input = None
162
+ if temp_image_path is not None:
163
+ image_input = handle_file(str(temp_image_path))
164
+
165
+ result = text_encoder_client.predict(
166
+ prompt=prompt,
167
+ enhance_prompt=True,
168
+ input_image=image_input,
169
+ seed=current_seed,
170
+ api_name="/encode_prompt"
171
+ )
172
+ embedding_path = result[0] # Path to .pt file
173
+ print(f"Embeddings received from: {embedding_path}")
174
+
175
+ # Load embeddings
176
+ embeddings = torch.load(embedding_path)
177
+ video_context = embeddings['video_context']
178
+ audio_context = embeddings['audio_context']
179
+ print("✓ Embeddings loaded successfully")
180
+ except Exception as e:
181
+ raise RuntimeError(
182
+ f"Failed to get embeddings from text encoder space: {e}\n"
183
+ f"Please ensure {TEXT_ENCODER_SPACE} is running properly."
184
+ )
185
 
186
  # Run inference - progress automatically tracks tqdm from pipeline
187
  pipeline(
 
194
  frame_rate=frame_rate,
195
  images=images,
196
  tiling_config=TilingConfig.default(),
197
+ video_context=video_context,
198
+ audio_context=audio_context,
199
  )
200
 
201
  return str(output_path), current_seed
 
307
  .gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}
308
  '''
309
  if __name__ == "__main__":
310
+ demo.launch(theme=gr.themes.Citrus(), css=css, share=True)