Vicente Alvarez committed on
Commit
b1a127d
·
1 Parent(s): 824f9f7

Switch to DistilledPipeline with pre-distilled sulphur_distil_bf16 checkpoint

Browse files
Files changed (1) hide show
  1. app.py +23 -82
app.py CHANGED
@@ -61,8 +61,7 @@ import gradio as gr
61
  import numpy as np
62
  from huggingface_hub import hf_hub_download, snapshot_download
63
 
64
- from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps
65
- from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline
66
  from ltx_pipelines.utils.args import ImageConditioningInput
67
  from ltx_pipelines.utils.media_io import encode_video
68
 
@@ -111,21 +110,17 @@ RESOLUTIONS = {
111
 
112
  # Model repos
113
  CHECKPOINT_REPO = "SulphurAI/Sulphur-2-base"
114
- DISTILL_LORA_REPO = "SulphurAI/Sulphur-2-base"
115
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
116
  GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
117
 
118
  # Download model checkpoints in parallel for speed
119
  print("=" * 80)
120
- print("Downloading Element-16 dev + distill LoRA + Gemma (parallel)...")
121
  print("=" * 80)
122
 
123
  def download_checkpoint():
124
- return hf_hub_download(repo_id=CHECKPOINT_REPO, filename="sulphur_dev_fp8mixed.safetensors")
125
-
126
- def download_lora():
127
- # Skip distill LoRA for fp8 - not compatible with mxfp8mixed format
128
- return None
129
 
130
  def download_upsampler():
131
  return hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
@@ -145,58 +140,27 @@ with ThreadPoolExecutor(max_workers=3) as executor:
145
  print(f"Checkpoint: {checkpoint_path}")
146
  print(f"Spatial upsampler: {spatial_upsampler_path}")
147
  print(f"Gemma root: {gemma_root}")
148
- print("Note: Using fp8 without distill LoRA - will use more inference steps")
149
-
150
- # Initialize pipeline with fp8 checkpoint (no distill LoRA for fp8 compatibility)
151
- pipeline = TI2VidTwoStagesHQPipeline(
152
- checkpoint_path=checkpoint_path,
153
- distilled_lora=[],
154
- distilled_lora_strength_stage_1=0.0,
155
- distilled_lora_strength_stage_2=0.0,
156
  spatial_upsampler_path=spatial_upsampler_path,
157
  gemma_root=gemma_root,
158
  loras=(),
159
  )
160
 
161
- # Preload all models for ZeroGPU tensor packing (BOTH stages!)
162
- print("Preloading all models (including Gemma and audio components)...")
163
-
164
- # Stage 1 models
165
- stage_1_ledger = pipeline.stage_1_model_ledger
166
- _transformer = stage_1_ledger.transformer()
167
- _video_encoder = stage_1_ledger.video_encoder()
168
- _video_decoder = stage_1_ledger.video_decoder()
169
- _audio_encoder = stage_1_ledger.audio_encoder()
170
- _audio_decoder = stage_1_ledger.audio_decoder()
171
- _vocoder = stage_1_ledger.vocoder()
172
- _spatial_upsampler_1 = stage_1_ledger.spatial_upsampler()
173
- _text_encoder = stage_1_ledger.text_encoder()
174
- _embeddings_processor = stage_1_ledger.gemma_embeddings_processor()
175
-
176
- stage_1_ledger.transformer = lambda: _transformer
177
- stage_1_ledger.video_encoder = lambda: _video_encoder
178
- stage_1_ledger.video_decoder = lambda: _video_decoder
179
- stage_1_ledger.audio_encoder = lambda: _audio_encoder
180
- stage_1_ledger.audio_decoder = lambda: _audio_decoder
181
- stage_1_ledger.vocoder = lambda: _vocoder
182
- stage_1_ledger.spatial_upsampler = lambda: _spatial_upsampler_1
183
- stage_1_ledger.text_encoder = lambda: _text_encoder
184
- stage_1_ledger.gemma_embeddings_processor = lambda: _embeddings_processor
185
-
186
- # Stage 2 models (critical - spatial upsampler is used here!)
187
- print("Preloading stage 2 models...")
188
- stage_2_ledger = pipeline.stage_2_model_ledger
189
- _spatial_upsampler_2 = stage_2_ledger.spatial_upsampler()
190
- _transformer_2 = stage_2_ledger.transformer()
191
- _video_encoder_2 = stage_2_ledger.video_encoder()
192
- _video_decoder_2 = stage_2_ledger.video_decoder()
193
-
194
- stage_2_ledger.spatial_upsampler = lambda: _spatial_upsampler_2
195
- stage_2_ledger.transformer = lambda: _transformer_2
196
- stage_2_ledger.video_encoder = lambda: _video_encoder_2
197
- stage_2_ledger.video_decoder = lambda: _video_decoder_2
198
-
199
- print("All models preloaded (stage 1 + stage 2)!")
200
 
201
  print("=" * 80)
202
  print("Pipeline ready!")
@@ -244,7 +208,7 @@ def on_highres_toggle(first_image, last_image, high_res):
244
  DEFAULT_NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走, blurry, glasses, deformed, subtitles, text, captions, worst quality, low quality, inconsistent motion, jittery, distorted"
245
 
246
 
247
- @spaces.GPU(duration=120) # More time needed for 30 inference steps
248
  @torch.inference_mode()
249
  def generate_video(
250
  first_image,
@@ -291,7 +255,6 @@ def generate_video(
291
  temp_last_path = Path(last_image)
292
  images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
293
 
294
- from ltx_core.components.guiders import MultiModalGuiderParams
295
  from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
296
 
297
  tiling_config = TilingConfig.default()
@@ -299,38 +262,16 @@ def generate_video(
299
 
300
  log_memory("before pipeline call")
301
 
302
- # Configure guider params
303
- video_guider_params = MultiModalGuiderParams(
304
- cfg_scale=3.0,
305
- stg_scale=0.0,
306
- rescale_scale=0.45,
307
- modality_scale=3.0,
308
- skip_step=0,
309
- stg_blocks=[],
310
- )
311
-
312
- audio_guider_params = MultiModalGuiderParams(
313
- cfg_scale=7.0,
314
- stg_scale=0.0,
315
- rescale_scale=1.0,
316
- modality_scale=3.0,
317
- skip_step=0,
318
- stg_blocks=[],
319
- )
320
-
321
- # Run inference - returns (video_frames_iter, audio)
322
  video_frames_iter, audio = pipeline(
323
  prompt=prompt,
324
- negative_prompt=negative_prompt,
325
  seed=current_seed,
326
  height=int(height),
327
  width=int(width),
328
  num_frames=num_frames,
329
  frame_rate=frame_rate,
330
- num_inference_steps=30, # More steps needed without distill LoRA
331
- video_guider_params=video_guider_params,
332
- audio_guider_params=audio_guider_params,
333
  images=images,
 
334
  )
335
 
336
  # Collect video frames
 
61
  import numpy as np
62
  from huggingface_hub import hf_hub_download, snapshot_download
63
 
64
+ from ltx_pipelines.distilled import DistilledPipeline
 
65
  from ltx_pipelines.utils.args import ImageConditioningInput
66
  from ltx_pipelines.utils.media_io import encode_video
67
 
 
110
 
111
  # Model repos
112
  CHECKPOINT_REPO = "SulphurAI/Sulphur-2-base"
 
113
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
114
  GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
115
 
116
  # Download model checkpoints in parallel for speed
117
  print("=" * 80)
118
+ print("Downloading Element-16 (pre-distilled) + Gemma (parallel)...")
119
  print("=" * 80)
120
 
121
  def download_checkpoint():
122
+ # Use pre-distilled checkpoint - no LoRA needed
123
+ return hf_hub_download(repo_id=CHECKPOINT_REPO, filename="sulphur_distil_bf16.safetensors")
 
 
 
124
 
125
  def download_upsampler():
126
  return hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
 
140
  print(f"Checkpoint: {checkpoint_path}")
141
  print(f"Spatial upsampler: {spatial_upsampler_path}")
142
  print(f"Gemma root: {gemma_root}")
143
+
144
+ # Initialize pipeline with pre-distilled checkpoint (no LoRA needed)
145
+ pipeline = DistilledPipeline(
146
+ distilled_checkpoint_path=checkpoint_path,
 
 
 
 
147
  spatial_upsampler_path=spatial_upsampler_path,
148
  gemma_root=gemma_root,
149
  loras=(),
150
  )
151
 
152
+ # Preload all models for ZeroGPU tensor packing
153
+ print("Preloading all pipeline components...")
154
+
155
+ # DistilledPipeline components are already instantiated, just access them to ensure loaded
156
+ _ = pipeline.prompt_encoder
157
+ _ = pipeline.image_conditioner
158
+ _ = pipeline.stage
159
+ _ = pipeline.upsampler
160
+ _ = pipeline.video_decoder
161
+ _ = pipeline.audio_decoder
162
+
163
+ print("All models preloaded!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  print("=" * 80)
166
  print("Pipeline ready!")
 
208
  DEFAULT_NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走, blurry, glasses, deformed, subtitles, text, captions, worst quality, low quality, inconsistent motion, jittery, distorted"
209
 
210
 
211
+ @spaces.GPU(duration=90)
212
  @torch.inference_mode()
213
  def generate_video(
214
  first_image,
 
255
  temp_last_path = Path(last_image)
256
  images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
257
 
 
258
  from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
259
 
260
  tiling_config = TilingConfig.default()
 
262
 
263
  log_memory("before pipeline call")
264
 
265
+ # Run inference - DistilledPipeline has simpler API
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  video_frames_iter, audio = pipeline(
267
  prompt=prompt,
 
268
  seed=current_seed,
269
  height=int(height),
270
  width=int(width),
271
  num_frames=num_frames,
272
  frame_rate=frame_rate,
 
 
 
273
  images=images,
274
+ enhance_prompt=enhance_prompt,
275
  )
276
 
277
  # Collect video frames