multimodalart HF Staff commited on
Commit
7345819
·
verified ·
1 Parent(s): c266f1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -89
app.py CHANGED
@@ -14,7 +14,6 @@ if not os.path.exists(LTX_REPO_DIR):
14
  print(f"Cloning {LTX_REPO_URL}...")
15
  subprocess.run(["git", "clone", "--depth", "1", LTX_REPO_URL, LTX_REPO_DIR], check=True)
16
 
17
- # Install ltx-core and ltx-pipelines if not already installed
18
  try:
19
  import ltx_pipelines # noqa: F401
20
  except ImportError:
@@ -32,6 +31,7 @@ sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
32
  import logging
33
  import random
34
  import tempfile
 
35
 
36
  import torch
37
  torch._dynamo.config.suppress_errors = True
@@ -40,7 +40,8 @@ torch._dynamo.config.disable = True
40
  import spaces
41
  import gradio as gr
42
  import numpy as np
43
- from huggingface_hub import hf_hub_download, snapshot_download
 
44
 
45
  from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
46
  from ltx_core.quantization import QuantizationPolicy
@@ -53,118 +54,178 @@ logging.getLogger().setLevel(logging.INFO)
53
  MAX_SEED = np.iinfo(np.int32).max
54
  DEFAULT_PROMPT = (
55
  "An astronaut hatches from a fragile egg on the surface of the Moon, "
56
- "the shell cracking and peeling apart in gentle low-gravity motion."
 
 
 
 
57
  )
58
  DEFAULT_HEIGHT = 1024
59
  DEFAULT_WIDTH = 1536
60
  DEFAULT_FRAME_RATE = 24.0
61
 
62
- # Download models from Hugging Face
63
  LTX_MODEL_REPO = "diffusers-internal-dev/ltx-23"
64
- GEMMA_MODEL_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"
65
 
 
 
 
 
66
  print("=" * 80)
67
- print("Downloading models from Hugging Face...")
68
  print("=" * 80)
69
 
70
- DISTILLED_CHECKPOINT = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
71
- SPATIAL_UPSAMPLER = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
72
- GEMMA_ROOT = snapshot_download(repo_id=GEMMA_MODEL_REPO)
73
-
74
- print(f"Distilled checkpoint: {DISTILLED_CHECKPOINT}")
75
- print(f"Spatial upsampler: {SPATIAL_UPSAMPLER}")
76
- print(f"Gemma root: {GEMMA_ROOT}")
77
 
78
- # Initialize pipeline
79
- print("=" * 80)
80
- print("Loading LTX-2.3 Distilled pipeline...")
81
- print("=" * 80)
82
 
 
 
83
  pipeline = DistilledPipeline(
84
- distilled_checkpoint_path=DISTILLED_CHECKPOINT,
85
- spatial_upsampler_path=SPATIAL_UPSAMPLER,
86
- gemma_root=GEMMA_ROOT,
87
  loras=[],
88
  quantization=QuantizationPolicy.fp8_cast(),
89
  )
90
 
91
- # Preload all models so first request is fast.
92
- # On ZeroGPU, .to('cuda') is intercepted and actual GPU allocation
93
- # happens inside the @spaces.GPU decorated function.
94
- print("Preloading models...")
95
- ledger = pipeline.model_ledger
96
- _text_encoder = ledger.text_encoder()
97
- _transformer = ledger.transformer()
98
- _video_encoder = ledger.video_encoder()
99
- _video_decoder = ledger.video_decoder()
100
- _audio_decoder = ledger.audio_decoder()
101
- _vocoder = ledger.vocoder()
102
- _spatial_upsampler = ledger.spatial_upsampler()
103
-
104
- ledger.text_encoder = lambda: _text_encoder
105
- ledger.transformer = lambda: _transformer
106
- ledger.video_encoder = lambda: _video_encoder
107
- ledger.video_decoder = lambda: _video_decoder
108
- ledger.audio_decoder = lambda: _audio_decoder
109
- ledger.vocoder = lambda: _vocoder
110
- ledger.spatial_upsampler = lambda: _spatial_upsampler
111
-
112
- print("All models preloaded!")
 
 
113
 
114
 
115
  @spaces.GPU(duration=120, size='xlarge')
116
- @torch.inference_mode()
117
  def generate_video(
118
  input_image,
119
  prompt: str,
120
  duration: float,
121
- enhance_prompt: bool,
122
- seed: int,
123
- randomize_seed: bool,
124
- height: int,
125
- width: int,
126
  progress=gr.Progress(track_tqdm=True),
127
  ):
128
- current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
129
- num_frames = int(duration * DEFAULT_FRAME_RATE) + 1
130
- num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1
131
-
132
- images = []
133
- if input_image is not None:
134
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
135
- temp_path = f.name
 
 
 
 
 
 
 
 
136
  if hasattr(input_image, "save"):
137
- input_image.save(temp_path)
138
  else:
139
- from shutil import copy2
140
- copy2(str(input_image), temp_path)
141
- images = [ImageConditioningInput(path=temp_path, frame_idx=0, strength=1.0)]
142
-
143
- tiling_config = TilingConfig.default()
144
- video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
145
-
146
- video, audio = pipeline(
147
- prompt=prompt,
148
- seed=current_seed,
149
- height=int(height),
150
- width=int(width),
151
- num_frames=num_frames,
152
- frame_rate=DEFAULT_FRAME_RATE,
153
- images=images,
154
- tiling_config=tiling_config,
155
- enhance_prompt=enhance_prompt,
156
- )
157
 
158
- output_path = tempfile.mktemp(suffix=".mp4")
159
- encode_video(
160
- video=video,
161
- fps=DEFAULT_FRAME_RATE,
162
- audio=audio,
163
- output_path=output_path,
164
- video_chunks_number=video_chunks_number,
165
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
- return output_path, current_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
 
170
  with gr.Blocks(title="LTX-2.3 Distilled") as demo:
@@ -180,15 +241,16 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
180
  input_image = gr.Image(label="Input Image (Optional)", type="pil")
181
  prompt = gr.Textbox(
182
  label="Prompt",
183
- value=DEFAULT_PROMPT,
 
184
  lines=3,
185
- placeholder="Describe the video you want to generate...",
186
  )
187
  with gr.Row():
188
- duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=5.0, step=0.5)
189
  enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=True)
190
 
191
- generate_btn = gr.Button("Generate Video", variant="primary")
192
 
193
  with gr.Accordion("Advanced Settings", open=False):
194
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=10, step=1)
@@ -210,5 +272,9 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
210
  )
211
 
212
 
 
 
 
 
213
  if __name__ == "__main__":
214
- demo.launch(share=True)
 
14
  print(f"Cloning {LTX_REPO_URL}...")
15
  subprocess.run(["git", "clone", "--depth", "1", LTX_REPO_URL, LTX_REPO_DIR], check=True)
16
 
 
17
  try:
18
  import ltx_pipelines # noqa: F401
19
  except ImportError:
 
31
  import logging
32
  import random
33
  import tempfile
34
+ from pathlib import Path
35
 
36
  import torch
37
  torch._dynamo.config.suppress_errors = True
 
40
  import spaces
41
  import gradio as gr
42
  import numpy as np
43
+ from gradio_client import Client, handle_file
44
+ from huggingface_hub import hf_hub_download
45
 
46
  from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
47
  from ltx_core.quantization import QuantizationPolicy
 
54
  MAX_SEED = np.iinfo(np.int32).max
55
  DEFAULT_PROMPT = (
56
  "An astronaut hatches from a fragile egg on the surface of the Moon, "
57
+ "the shell cracking and peeling apart in gentle low-gravity motion. "
58
+ "Fine lunar dust lifts and drifts outward with each movement, floating "
59
+ "in slow arcs before settling back onto the ground. The astronaut pushes "
60
+ "free in a deliberate, weightless motion, small fragments of the egg "
61
+ "tumbling and spinning through the air."
62
  )
63
  DEFAULT_HEIGHT = 1024
64
  DEFAULT_WIDTH = 1536
65
  DEFAULT_FRAME_RATE = 24.0
66
 
67
+ # Model repo
68
  LTX_MODEL_REPO = "diffusers-internal-dev/ltx-23"
 
69
 
70
+ # Text encoder space URL - must be a 2.3-compatible text encoder
71
+ TEXT_ENCODER_SPACE = "multimodalart/gemma-text-encoder-ltx23"
72
+
73
+ # Download model checkpoints
74
  print("=" * 80)
75
+ print("Downloading LTX-2.3 distilled model...")
76
  print("=" * 80)
77
 
78
+ checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
79
+ spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
 
 
 
 
 
80
 
81
+ print(f"Checkpoint: {checkpoint_path}")
82
+ print(f"Spatial upsampler: {spatial_upsampler_path}")
 
 
83
 
84
+ # Initialize pipeline WITHOUT text encoder (gemma_root=None)
85
+ # Text encoding will be done by external space
86
  pipeline = DistilledPipeline(
87
+ distilled_checkpoint_path=checkpoint_path,
88
+ spatial_upsampler_path=spatial_upsampler_path,
89
+ gemma_root=None,
90
  loras=[],
91
  quantization=QuantizationPolicy.fp8_cast(),
92
  )
93
 
94
+ # Connect to text encoder space
95
+ print(f"Connecting to text encoder space: {TEXT_ENCODER_SPACE}")
96
+ try:
97
+ text_encoder_client = Client(TEXT_ENCODER_SPACE)
98
+ print("Text encoder client connected!")
99
+ except Exception as e:
100
+ print(f"Warning: Could not connect to text encoder space: {e}")
101
+ text_encoder_client = None
102
+
103
+ print("=" * 80)
104
+ print("Pipeline ready!")
105
+ print("=" * 80)
106
+
107
+
108
class PrecomputedTextEncoder(torch.nn.Module):
    """Drop-in replacement for the text encoder that serves cached embeddings.

    The real Gemma encoder runs in a separate Space; this module simply hands
    back the video/audio context tensors that were computed there, so the
    pipeline can be constructed with ``gemma_root=None``.
    """

    def __init__(self, video_context, audio_context):
        super().__init__()
        # Pre-computed context tensors received from the text-encoder Space.
        self.video_context = video_context
        self.audio_context = audio_context

    def forward(self, text, padding_side="left"):
        # The prompt text and padding side are deliberately ignored — the
        # embeddings were already produced externally. Third element (the
        # attention mask slot) is always None.
        return self.video_context, self.audio_context, None
 
119
 
120
@spaces.GPU(duration=120, size='xlarge')
def generate_video(
    input_image,
    prompt: str,
    duration: float,
    enhance_prompt: bool = True,
    seed: int = 42,
    randomize_seed: bool = True,
    height: int = DEFAULT_HEIGHT,
    width: int = DEFAULT_WIDTH,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video (with audio) from a prompt and optional input image.

    The prompt is encoded by an external text-encoder Space; the resulting
    embeddings are injected into the pipeline via a temporary fake text
    encoder patched onto ``pipeline.model_ledger``.

    Returns:
        tuple[str | None, int]: path to the encoded .mp4 (or None on error)
        and the seed that was actually used.
    """
    # Resolve the seed BEFORE the try-block: the except handler below reports
    # it, and it must never be an unbound name if an early statement raises.
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    try:
        frame_rate = DEFAULT_FRAME_RATE
        num_frames = int(duration * frame_rate) + 1
        # The model requires frame counts in 8k+1 format.
        num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1

        # Handle optional image conditioning (first frame, full strength).
        images = []
        temp_image_path = None
        if input_image is not None:
            output_dir = Path("outputs")
            output_dir.mkdir(exist_ok=True)
            temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
            if hasattr(input_image, "save"):
                # PIL image from the Gradio component.
                input_image.save(temp_image_path)
            else:
                # Already a file path on disk.
                temp_image_path = Path(input_image)
            images = [ImageConditioningInput(path=str(temp_image_path), frame_idx=0, strength=1.0)]

        # Get embeddings from the external text-encoder Space.
        print(f"Encoding prompt: {prompt}")

        if text_encoder_client is None:
            raise RuntimeError(
                f"Text encoder client not connected. Please ensure the text encoder space "
                f"({TEXT_ENCODER_SPACE}) is running and accessible."
            )

        try:
            image_input = None
            if temp_image_path is not None:
                image_input = handle_file(str(temp_image_path))

            result = text_encoder_client.predict(
                prompt=prompt,
                enhance_prompt=enhance_prompt,
                input_image=image_input,
                seed=current_seed,
                negative_prompt="",
                api_name="/encode_prompt",
            )
            embedding_path = result[0]
            print(f"Embeddings received from: {embedding_path}")

            # NOTE(review): torch.load on a file produced by another service —
            # trusted here because the Space is ours, but worth confirming.
            embeddings = torch.load(embedding_path)
            video_context = embeddings["video_context"].to("cuda")
            audio_context = embeddings["audio_context"].to("cuda")
            print("Embeddings loaded successfully")
        except Exception as e:
            raise RuntimeError(
                f"Failed to get embeddings from text encoder space: {e}\n"
                f"Please ensure {TEXT_ENCODER_SPACE} is running properly."
            )

        # Patch the model_ledger to return a fake text encoder that serves the
        # pre-computed embeddings; restore the original in the finally block.
        fake_encoder = PrecomputedTextEncoder(video_context, audio_context)
        original_text_encoder_fn = pipeline.model_ledger.text_encoder
        pipeline.model_ledger.text_encoder = lambda: fake_encoder

        try:
            tiling_config = TilingConfig.default()
            video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

            video, audio = pipeline(
                prompt=prompt,
                seed=current_seed,
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=frame_rate,
                images=images,
                tiling_config=tiling_config,
                enhance_prompt=False,  # Already enhanced by text encoder space
            )

            # Use NamedTemporaryFile instead of the deprecated, race-prone
            # tempfile.mktemp; delete=False keeps the file for Gradio to serve.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                output_path = tmp.name
            encode_video(
                video=video,
                fps=frame_rate,
                audio=audio,
                output_path=output_path,
                video_chunks_number=video_chunks_number,
            )

            return str(output_path), current_seed
        finally:
            # Restore original text encoder method
            pipeline.model_ledger.text_encoder = original_text_encoder_fn

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        # Best-effort UI behavior: surface no video rather than crash the app.
        return None, current_seed
229
 
230
 
231
  with gr.Blocks(title="LTX-2.3 Distilled") as demo:
 
241
  input_image = gr.Image(label="Input Image (Optional)", type="pil")
242
  prompt = gr.Textbox(
243
  label="Prompt",
244
+ info="for best results - make it as elaborate as possible",
245
+ value="Make this image come alive with cinematic motion, smooth animation",
246
  lines=3,
247
+ placeholder="Describe the motion and animation you want...",
248
  )
249
  with gr.Row():
250
+ duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
251
  enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=True)
252
 
253
+ generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
254
 
255
  with gr.Accordion("Advanced Settings", open=False):
256
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=10, step=1)
 
272
  )
273
 
274
 
275
# NOTE(review): `Blocks.launch()` does not accept `theme=` or `css=` — both are
# `gr.Blocks(...)` constructor arguments; passing them to launch() raises
# TypeError on current Gradio. The CSS is kept here so it can be moved into the
# `gr.Blocks(title=..., theme=gr.themes.Citrus(), css=css)` call — confirm
# against the installed Gradio version.
css = """
.gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}
"""

if __name__ == "__main__":
    demo.launch()