Vicente Alvarez committed on
Commit
e162f46
·
1 Parent(s): ce3e28e

Sulphur dev + distill LoRA with TI2VidTwoStagesHQPipeline

Browse files
Files changed (1) hide show
  1. app.py +78 -211
app.py CHANGED
@@ -52,24 +52,10 @@ import gradio as gr
52
  import numpy as np
53
  from huggingface_hub import hf_hub_download, snapshot_download
54
 
55
- from ltx_core.components.diffusion_steps import EulerDiffusionStep
56
- from ltx_core.components.noisers import GaussianNoiser
57
- from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
58
- from ltx_core.model.upsampler import upsample_video
59
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number, decode_video as vae_decode_video
60
- from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
61
- from ltx_pipelines.distilled import DistilledPipeline
62
- from ltx_pipelines.utils import euler_denoising_loop
63
  from ltx_pipelines.utils.args import ImageConditioningInput
64
- from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
65
- from ltx_pipelines.utils.helpers import (
66
- cleanup_memory,
67
- combined_image_conditionings,
68
- denoise_video_only,
69
- encode_prompts,
70
- simple_denoising_func,
71
- )
72
- from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
73
 
74
  # Patch attention backend into the LTX attention module.
75
  import torch.nn.functional as F
@@ -114,217 +100,69 @@ RESOLUTIONS = {
114
  }
115
 
116
 
117
- class LTX23DistilledA2VPipeline(DistilledPipeline):
118
- """DistilledPipeline with optional audio conditioning."""
119
-
120
- def __call__(
121
- self,
122
- prompt: str,
123
- seed: int,
124
- height: int,
125
- width: int,
126
- num_frames: int,
127
- frame_rate: float,
128
- images: list[ImageConditioningInput],
129
- audio_path: str | None = None,
130
- tiling_config: TilingConfig | None = None,
131
- enhance_prompt: bool = False,
132
- ):
133
- # Standard path when no audio input is provided.
134
- if audio_path is None:
135
- return super().__call__(
136
- prompt=prompt,
137
- seed=seed,
138
- height=height,
139
- width=width,
140
- num_frames=num_frames,
141
- frame_rate=frame_rate,
142
- images=images,
143
- tiling_config=tiling_config,
144
- enhance_prompt=enhance_prompt,
145
- )
146
-
147
- generator = torch.Generator(device=self.device).manual_seed(seed)
148
- noiser = GaussianNoiser(generator=generator)
149
- stepper = EulerDiffusionStep()
150
- dtype = torch.bfloat16
151
-
152
- (ctx_p,) = encode_prompts(
153
- [prompt],
154
- self.model_ledger,
155
- enhance_first_prompt=enhance_prompt,
156
- enhance_prompt_image=images[0].path if len(images) > 0 else None,
157
- )
158
- video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
159
-
160
- video_duration = num_frames / frame_rate
161
- decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
162
- if decoded_audio is None:
163
- raise ValueError(f"Could not extract audio stream from {audio_path}")
164
-
165
- encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
166
- audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
167
- expected_frames = audio_shape.frames
168
- actual_frames = encoded_audio_latent.shape[2]
169
-
170
- if actual_frames > expected_frames:
171
- encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
172
- elif actual_frames < expected_frames:
173
- pad = torch.zeros(
174
- encoded_audio_latent.shape[0],
175
- encoded_audio_latent.shape[1],
176
- expected_frames - actual_frames,
177
- encoded_audio_latent.shape[3],
178
- device=encoded_audio_latent.device,
179
- dtype=encoded_audio_latent.dtype,
180
- )
181
- encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
182
-
183
- video_encoder = self.model_ledger.video_encoder()
184
- transformer = self.model_ledger.transformer()
185
- stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
186
-
187
- def denoising_loop(sigmas, video_state, audio_state, stepper):
188
- return euler_denoising_loop(
189
- sigmas=sigmas,
190
- video_state=video_state,
191
- audio_state=audio_state,
192
- stepper=stepper,
193
- denoise_fn=simple_denoising_func(
194
- video_context=video_context,
195
- audio_context=audio_context,
196
- transformer=transformer,
197
- ),
198
- )
199
-
200
- stage_1_output_shape = VideoPixelShape(
201
- batch=1,
202
- frames=num_frames,
203
- width=width // 2,
204
- height=height // 2,
205
- fps=frame_rate,
206
- )
207
- stage_1_conditionings = combined_image_conditionings(
208
- images=images,
209
- height=stage_1_output_shape.height,
210
- width=stage_1_output_shape.width,
211
- video_encoder=video_encoder,
212
- dtype=dtype,
213
- device=self.device,
214
- )
215
- video_state = denoise_video_only(
216
- output_shape=stage_1_output_shape,
217
- conditionings=stage_1_conditionings,
218
- noiser=noiser,
219
- sigmas=stage_1_sigmas,
220
- stepper=stepper,
221
- denoising_loop_fn=denoising_loop,
222
- components=self.pipeline_components,
223
- dtype=dtype,
224
- device=self.device,
225
- initial_audio_latent=encoded_audio_latent,
226
- )
227
-
228
- torch.cuda.synchronize()
229
- cleanup_memory()
230
-
231
- upscaled_video_latent = upsample_video(
232
- latent=video_state.latent[:1],
233
- video_encoder=video_encoder,
234
- upsampler=self.model_ledger.spatial_upsampler(),
235
- )
236
- stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
237
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
238
- stage_2_conditionings = combined_image_conditionings(
239
- images=images,
240
- height=stage_2_output_shape.height,
241
- width=stage_2_output_shape.width,
242
- video_encoder=video_encoder,
243
- dtype=dtype,
244
- device=self.device,
245
- )
246
- video_state = denoise_video_only(
247
- output_shape=stage_2_output_shape,
248
- conditionings=stage_2_conditionings,
249
- noiser=noiser,
250
- sigmas=stage_2_sigmas,
251
- stepper=stepper,
252
- denoising_loop_fn=denoising_loop,
253
- components=self.pipeline_components,
254
- dtype=dtype,
255
- device=self.device,
256
- noise_scale=stage_2_sigmas[0],
257
- initial_video_latent=upscaled_video_latent,
258
- initial_audio_latent=encoded_audio_latent,
259
- )
260
-
261
- torch.cuda.synchronize()
262
- del transformer
263
- del video_encoder
264
- cleanup_memory()
265
-
266
- decoded_video = vae_decode_video(
267
- video_state.latent,
268
- self.model_ledger.video_decoder(),
269
- tiling_config,
270
- generator,
271
- )
272
- original_audio = Audio(
273
- waveform=decoded_audio.waveform.squeeze(0),
274
- sampling_rate=decoded_audio.sampling_rate,
275
- )
276
- return decoded_video, original_audio
277
-
278
-
279
  # Model repos
280
- CHECKPOINT_REPO = "Civitai/Sulphur-2-distilled-fp8"
 
281
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
282
  GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
283
 
284
  # Download model checkpoints
285
  print("=" * 80)
286
- print("Downloading Element-16 fp8 model + Gemma...")
287
  print("=" * 80)
288
 
289
- checkpoint_path = hf_hub_download(repo_id=CHECKPOINT_REPO, filename="sulphur_distil_fp8mixed.safetensors")
 
290
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
291
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
292
 
293
  print(f"Checkpoint: {checkpoint_path}")
 
294
  print(f"Spatial upsampler: {spatial_upsampler_path}")
295
  print(f"Gemma root: {gemma_root}")
296
 
297
- # Initialize pipeline WITH text encoder and optional audio support
298
- # No fp8_cast needed - checkpoint is already fp8 quantized
299
- pipeline = LTX23DistilledA2VPipeline(
300
- distilled_checkpoint_path=checkpoint_path,
 
 
 
 
 
 
 
 
 
 
 
301
  spatial_upsampler_path=spatial_upsampler_path,
302
  gemma_root=gemma_root,
303
- loras=[],
304
  )
305
 
306
  # Preload all models for ZeroGPU tensor packing.
307
  print("Preloading all models (including Gemma and audio components)...")
308
- ledger = pipeline.model_ledger
309
- _transformer = ledger.transformer()
310
- _video_encoder = ledger.video_encoder()
311
- _video_decoder = ledger.video_decoder()
312
- _audio_encoder = ledger.audio_encoder()
313
- _audio_decoder = ledger.audio_decoder()
314
- _vocoder = ledger.vocoder()
315
- _spatial_upsampler = ledger.spatial_upsampler()
316
- _text_encoder = ledger.text_encoder()
317
- _embeddings_processor = ledger.gemma_embeddings_processor()
318
-
319
- ledger.transformer = lambda: _transformer
320
- ledger.video_encoder = lambda: _video_encoder
321
- ledger.video_decoder = lambda: _video_decoder
322
- ledger.audio_encoder = lambda: _audio_encoder
323
- ledger.audio_decoder = lambda: _audio_decoder
324
- ledger.vocoder = lambda: _vocoder
325
- ledger.spatial_upsampler = lambda: _spatial_upsampler
326
- ledger.text_encoder = lambda: _text_encoder
327
- ledger.gemma_embeddings_processor = lambda: _embeddings_processor
328
  print("All models preloaded (including Gemma text encoder and audio encoder)!")
329
 
330
  print("=" * 80)
@@ -420,28 +258,57 @@ def generate_video(
420
  temp_last_path = Path(last_image)
421
  images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
422
 
 
 
 
423
  tiling_config = TilingConfig.default()
424
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
425
 
426
  log_memory("before pipeline call")
427
 
428
- video, audio = pipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  prompt=prompt,
 
430
  seed=current_seed,
431
  height=int(height),
432
  width=int(width),
433
  num_frames=num_frames,
434
  frame_rate=frame_rate,
 
 
 
435
  images=images,
436
- tiling_config=tiling_config,
437
- enhance_prompt=enhance_prompt,
438
  )
439
 
 
 
 
 
440
  log_memory("after pipeline call")
441
 
442
  output_path = tempfile.mktemp(suffix=".mp4")
443
  encode_video(
444
- video=video,
445
  fps=frame_rate,
446
  audio=audio,
447
  output_path=output_path,
 
52
  import numpy as np
53
  from huggingface_hub import hf_hub_download, snapshot_download
54
 
55
+ from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps
56
+ from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline
 
 
 
 
 
 
57
  from ltx_pipelines.utils.args import ImageConditioningInput
58
+ from ltx_pipelines.utils.media_io import encode_video
 
 
 
 
 
 
 
 
59
 
60
  # Patch attention backend into the LTX attention module.
61
  import torch.nn.functional as F
 
100
  }
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  # Model repos
104
+ CHECKPOINT_REPO = "SulphurAI/Sulphur-2-base"
105
+ DISTILL_LORA_REPO = "SulphurAI/Sulphur-2-base"
106
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
107
  GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
108
 
109
  # Download model checkpoints
110
  print("=" * 80)
111
+ print("Downloading Element-16 dev + distill LoRA + Gemma...")
112
  print("=" * 80)
113
 
114
+ checkpoint_path = hf_hub_download(repo_id=CHECKPOINT_REPO, filename="sulphur_dev_fp8mixed.safetensors")
115
+ distilled_lora_path = hf_hub_download(repo_id=DISTILL_LORA_REPO, filename="distill_loras/ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors")
116
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
117
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
118
 
119
  print(f"Checkpoint: {checkpoint_path}")
120
+ print(f"Distilled LoRA: {distilled_lora_path}")
121
  print(f"Spatial upsampler: {spatial_upsampler_path}")
122
  print(f"Gemma root: {gemma_root}")
123
 
124
+ # Create distilled LoRA entry
125
+ distilled_lora = [
126
+ LoraPathStrengthAndSDOps(
127
+ distilled_lora_path,
128
+ 1.0,
129
+ LTXV_LORA_COMFY_RENAMING_MAP,
130
+ ),
131
+ ]
132
+
133
+ # Initialize pipeline with dev checkpoint + distilled LoRA
134
+ pipeline = TI2VidTwoStagesHQPipeline(
135
+ checkpoint_path=checkpoint_path,
136
+ distilled_lora=distilled_lora,
137
+ distilled_lora_strength_stage_1=0.25,
138
+ distilled_lora_strength_stage_2=0.5,
139
  spatial_upsampler_path=spatial_upsampler_path,
140
  gemma_root=gemma_root,
141
+ loras=(),
142
  )
143
 
144
  # Preload all models for ZeroGPU tensor packing.
145
  print("Preloading all models (including Gemma and audio components)...")
146
+ stage_1_ledger = pipeline.stage_1_model_ledger
147
+ _transformer = stage_1_ledger.transformer()
148
+ _video_encoder = stage_1_ledger.video_encoder()
149
+ _video_decoder = stage_1_ledger.video_decoder()
150
+ _audio_encoder = stage_1_ledger.audio_encoder()
151
+ _audio_decoder = stage_1_ledger.audio_decoder()
152
+ _vocoder = stage_1_ledger.vocoder()
153
+ _spatial_upsampler = stage_1_ledger.spatial_upsampler()
154
+ _text_encoder = stage_1_ledger.text_encoder()
155
+ _embeddings_processor = stage_1_ledger.gemma_embeddings_processor()
156
+
157
+ stage_1_ledger.transformer = lambda: _transformer
158
+ stage_1_ledger.video_encoder = lambda: _video_encoder
159
+ stage_1_ledger.video_decoder = lambda: _video_decoder
160
+ stage_1_ledger.audio_encoder = lambda: _audio_encoder
161
+ stage_1_ledger.audio_decoder = lambda: _audio_decoder
162
+ stage_1_ledger.vocoder = lambda: _vocoder
163
+ stage_1_ledger.spatial_upsampler = lambda: _spatial_upsampler
164
+ stage_1_ledger.text_encoder = lambda: _text_encoder
165
+ stage_1_ledger.gemma_embeddings_processor = lambda: _embeddings_processor
166
  print("All models preloaded (including Gemma text encoder and audio encoder)!")
167
 
168
  print("=" * 80)
 
258
  temp_last_path = Path(last_image)
259
  images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
260
 
261
+ from ltx_core.components.guiders import MultiModalGuiderParams
262
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
263
+
264
  tiling_config = TilingConfig.default()
265
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
266
 
267
  log_memory("before pipeline call")
268
 
269
+ # Configure guider params
270
+ video_guider_params = MultiModalGuiderParams(
271
+ cfg_scale=3.0,
272
+ stg_scale=0.0,
273
+ rescale_scale=0.45,
274
+ modality_scale=3.0,
275
+ skip_step=0,
276
+ stg_blocks=[],
277
+ )
278
+
279
+ audio_guider_params = MultiModalGuiderParams(
280
+ cfg_scale=7.0,
281
+ stg_scale=0.0,
282
+ rescale_scale=1.0,
283
+ modality_scale=3.0,
284
+ skip_step=0,
285
+ stg_blocks=[],
286
+ )
287
+
288
+ # Run inference - returns (video_frames_iter, audio)
289
+ video_frames_iter, audio = pipeline(
290
  prompt=prompt,
291
+ negative_prompt=negative_prompt,
292
  seed=current_seed,
293
  height=int(height),
294
  width=int(width),
295
  num_frames=num_frames,
296
  frame_rate=frame_rate,
297
+ num_inference_steps=15,
298
+ video_guider_params=video_guider_params,
299
+ audio_guider_params=audio_guider_params,
300
  images=images,
 
 
301
  )
302
 
303
+ # Collect video frames
304
+ frames = [frame for frame in video_frames_iter]
305
+ video_tensor = torch.cat(frames, dim=0) if len(frames) > 1 else frames[0]
306
+
307
  log_memory("after pipeline call")
308
 
309
  output_path = tempfile.mktemp(suffix=".mp4")
310
  encode_video(
311
+ video=video_tensor,
312
  fps=frame_rate,
313
  audio=audio,
314
  output_path=output_path,