linoyts HF Staff commited on
Commit
c3f3413
·
verified ·
1 Parent(s): f48dcae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -13
app.py CHANGED
@@ -9,6 +9,9 @@ import numpy as np
9
  import random
10
  import spaces
11
  import gradio as gr
 
 
 
12
  from typing import Optional
13
  from huggingface_hub import hf_hub_download
14
  from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
@@ -33,11 +36,13 @@ DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the
33
 
34
  # HuggingFace Hub defaults
35
  DEFAULT_REPO_ID = "Lightricks/LTX-2"
36
- DEFAULT_GEMMA_REPO_ID = "google/gemma-3-12b-it-qat-q4_0-unquantized"
37
  DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
38
  DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384.safetensors"
39
  DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0.safetensors"
40
 
 
 
 
41
  def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
42
  """Download from HuggingFace Hub or use local checkpoint."""
43
  if repo_id is None and filename is None:
@@ -68,24 +73,36 @@ print(f"Initializing pipeline with:")
68
  print(f" checkpoint_path={checkpoint_path}")
69
  print(f" distilled_lora_path={distilled_lora_path}")
70
  print(f" spatial_upsampler_path={spatial_upsampler_path}")
71
- print(f" gemma_root={DEFAULT_GEMMA_REPO_ID}")
72
 
 
 
73
  pipeline = TI2VidTwoStagesPipeline(
74
  checkpoint_path=checkpoint_path,
75
  distilled_lora_path=distilled_lora_path,
76
  distilled_lora_strength=DEFAULT_LORA_STRENGTH,
77
  spatial_upsampler_path=spatial_upsampler_path,
78
- gemma_root=DEFAULT_GEMMA_REPO_ID,
79
  loras=[],
80
  fp8transformer=False,
81
  local_files_only=False
82
  )
83
 
 
 
 
 
 
 
 
 
 
84
  @spaces.GPU(duration=300)
85
  def generate_video(
86
  input_image,
87
  prompt: str,
88
  duration: float,
 
89
  negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
90
  seed: int = 42,
91
  randomize_seed: bool = True,
@@ -107,20 +124,56 @@ def generate_video(
107
  # Create output directory if it doesn't exist
108
  output_dir = Path("outputs")
109
  output_dir.mkdir(exist_ok=True)
110
- output_path = output_dir / f"video_{seed}.mp4"
111
 
112
  # Handle image input
113
  images = []
 
114
  if input_image is not None:
115
  # Save uploaded image temporarily
116
- temp_image_path = output_dir / f"temp_input_{seed}.jpg"
117
  if hasattr(input_image, 'save'):
118
  input_image.save(temp_image_path)
119
  else:
120
  # If it's a file path already
121
- temp_image_path = input_image
122
  # Format: (image_path, frame_idx, strength)
123
  images = [(str(temp_image_path), 0, 1.0)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  # Run inference - progress automatically tracks tqdm from pipeline
126
  pipeline(
@@ -136,6 +189,8 @@ def generate_video(
136
  cfg_guidance_scale=cfg_guidance_scale,
137
  images=images,
138
  tiling_config=TilingConfig.default(),
 
 
139
  )
140
 
141
  return str(output_path), current_seed
@@ -166,13 +221,18 @@ with gr.Blocks(title="LTX-2 Video 🎥🔈") as demo:
166
  placeholder="Describe the motion and animation you want..."
167
  )
168
 
169
- duration = gr.Slider(
170
- label="Duration (seconds)",
171
- minimum=1.0,
172
- maximum=10.0,
173
- value=3.0,
174
- step=0.1
175
- )
 
 
 
 
 
176
 
177
  generate_btn = gr.Button("Generate Video", variant="primary")
178
 
 
9
  import random
10
  import spaces
11
  import gradio as gr
12
+ from gradio_client import Client, handle_file
13
+ import torch
14
+ from pathlib import Path
15
  from typing import Optional
16
  from huggingface_hub import hf_hub_download
17
  from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
 
36
 
37
  # HuggingFace Hub defaults
38
  DEFAULT_REPO_ID = "Lightricks/LTX-2"
 
39
  DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
40
  DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384.safetensors"
41
  DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0.safetensors"
42
 
43
+ # Text encoder space URL
44
+ TEXT_ENCODER_SPACE = "linoyts/gemma-text-encoder"
45
+
46
  def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None):
47
  """Download from HuggingFace Hub or use local checkpoint."""
48
  if repo_id is None and filename is None:
 
73
  print(f" checkpoint_path={checkpoint_path}")
74
  print(f" distilled_lora_path={distilled_lora_path}")
75
  print(f" spatial_upsampler_path={spatial_upsampler_path}")
76
+ print(f" text_encoder_space={TEXT_ENCODER_SPACE}")
77
 
78
+ # Initialize pipeline WITHOUT text encoder (gemma_root=None)
79
+ # Text encoding will be done by external space
80
  pipeline = TI2VidTwoStagesPipeline(
81
  checkpoint_path=checkpoint_path,
82
  distilled_lora_path=distilled_lora_path,
83
  distilled_lora_strength=DEFAULT_LORA_STRENGTH,
84
  spatial_upsampler_path=spatial_upsampler_path,
85
+ gemma_root=None,
86
  loras=[],
87
  fp8transformer=False,
88
  local_files_only=False
89
  )
90
 
91
+ # Initialize text encoder client
92
+ print(f"Connecting to text encoder space: {TEXT_ENCODER_SPACE}")
93
+ try:
94
+ text_encoder_client = Client(TEXT_ENCODER_SPACE)
95
+ print("✓ Text encoder client connected!")
96
+ except Exception as e:
97
+ print(f"⚠ Warning: Could not connect to text encoder space: {e}")
98
+ text_encoder_client = None
99
+
100
  @spaces.GPU(duration=300)
101
  def generate_video(
102
  input_image,
103
  prompt: str,
104
  duration: float,
105
+ enhance_prompt: bool = True,
106
  negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
107
  seed: int = 42,
108
  randomize_seed: bool = True,
 
124
  # Create output directory if it doesn't exist
125
  output_dir = Path("outputs")
126
  output_dir.mkdir(exist_ok=True)
127
+ output_path = output_dir / f"video_{current_seed}.mp4"
128
 
129
  # Handle image input
130
  images = []
131
+ temp_image_path = None # Initialize to None
132
  if input_image is not None:
133
  # Save uploaded image temporarily
134
+ temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"
135
  if hasattr(input_image, 'save'):
136
  input_image.save(temp_image_path)
137
  else:
138
  # If it's a file path already
139
+ temp_image_path = Path(input_image)
140
  # Format: (image_path, frame_idx, strength)
141
  images = [(str(temp_image_path), 0, 1.0)]
142
+ # Get embeddings from text encoder space
143
+ print(f"Encoding prompt: {prompt}")
144
+
145
+ if text_encoder_client is None:
146
+ raise RuntimeError(
147
+ f"Text encoder client not connected. Please ensure the text encoder space "
148
+ f"({TEXT_ENCODER_SPACE}) is running and accessible."
149
+ )
150
+
151
+ try:
152
+ # Prepare image for upload if it exists
153
+ image_input = None
154
+ if temp_image_path is not None:
155
+ image_input = handle_file(str(temp_image_path))
156
+
157
+ result = text_encoder_client.predict(
158
+ prompt=prompt,
159
+ enhance_prompt=enhance_prompt,
160
+ input_image=image_input,
161
+ seed=current_seed,
162
+ api_name="/encode_prompt"
163
+ )
164
+ embedding_path = result[0] # Path to .pt file
165
+ print(f"Embeddings received from: {embedding_path}")
166
+
167
+ # Load embeddings
168
+ embeddings = torch.load(embedding_path)
169
+ video_context = embeddings['video_context']
170
+ audio_context = embeddings['audio_context']
171
+ print("✓ Embeddings loaded successfully")
172
+ except Exception as e:
173
+ raise RuntimeError(
174
+ f"Failed to get embeddings from text encoder space: {e}\n"
175
+ f"Please ensure {TEXT_ENCODER_SPACE} is running properly."
176
+ )
177
 
178
  # Run inference - progress automatically tracks tqdm from pipeline
179
  pipeline(
 
189
  cfg_guidance_scale=cfg_guidance_scale,
190
  images=images,
191
  tiling_config=TilingConfig.default(),
192
+ video_context=video_context,
193
+ audio_context=audio_context,
194
  )
195
 
196
  return str(output_path), current_seed
 
221
  placeholder="Describe the motion and animation you want..."
222
  )
223
 
224
+ with gr.Row():
225
+ duration = gr.Slider(
226
+ label="Duration (seconds)",
227
+ minimum=1.0,
228
+ maximum=10.0,
229
+ value=3.0,
230
+ step=0.1
231
+ )
232
+ enhance_prompt = gr.Checkbox(
233
+ label="Enhance Prompt",
234
+ value=True
235
+ )
236
 
237
  generate_btn = gr.Button("Generate Video", variant="primary")
238