dagloop5 committed on
Commit
4f31ee8
·
verified ·
1 Parent(s): 884d0d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -57
app.py CHANGED
@@ -217,7 +217,7 @@ class LTX23DistilledA2VPipeline:
217
 
218
  # Stage 1: Generate sigmas using LTX2Scheduler with user-specified steps
219
  empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(
220
- VideoPixelShape(batch=1, frames=num_frames, width=width // 2, height=height // 2, fps=frame_rate)
221
  ).to_torch_shape())
222
  stage_1_sigmas = (
223
  LTX2Scheduler()
@@ -246,25 +246,12 @@ class LTX23DistilledA2VPipeline:
246
  ),
247
  )
248
 
249
- def stage2_denoising_loop(sigmas: torch.Tensor, video_state, audio_state, stepper: DiffusionStepProtocol):
250
- return euler_denoising_loop(
251
- sigmas=sigmas,
252
- video_state=video_state,
253
- audio_state=audio_state,
254
- stepper=stepper,
255
- denoise_fn=simple_denoising_func(
256
- video_context=v_context_p,
257
- audio_context=a_context_p,
258
- transformer=transformer, # noqa: F821
259
- ),
260
- )
261
-
262
  # ── Stage 1: Half resolution ──
263
  stage_1_output_shape = VideoPixelShape(
264
  batch=1,
265
  frames=num_frames,
266
- width=width // 2,
267
- height=height // 2,
268
  fps=frame_rate,
269
  )
270
  stage_1_conditionings = combined_image_conditionings(
@@ -294,42 +281,6 @@ class LTX23DistilledA2VPipeline:
294
  torch.cuda.synchronize()
295
  # cleanup_memory()
296
 
297
- # ── Upscaling ──
298
- upscaled_video_latent = upsample_video(
299
- latent=video_state.latent[:1],
300
- video_encoder=video_encoder,
301
- upsampler=self.model_ledger.spatial_upsampler(),
302
- )
303
-
304
- # ── Stage 2: Full resolution ──
305
- stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
306
- stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
307
- stage_2_conditionings = combined_image_conditionings(
308
- images=images,
309
- height=stage_2_output_shape.height,
310
- width=stage_2_output_shape.width,
311
- video_encoder=video_encoder,
312
- dtype=dtype,
313
- device=self.device,
314
- )
315
- video_state, audio_state = denoise_audio_video(
316
- output_shape=stage_2_output_shape,
317
- conditionings=stage_2_conditionings,
318
- noiser=noiser,
319
- sigmas=stage_2_sigmas,
320
- stepper=stepper,
321
- denoising_loop_fn=stage2_denoising_loop,
322
- components=self.pipeline_components,
323
- dtype=dtype,
324
- device=self.device,
325
- noise_scale=stage_2_sigmas[0],
326
- initial_video_latent=upscaled_video_latent,
327
- initial_audio_latent=audio_state.latent,
328
- )
329
-
330
- torch.cuda.synchronize()
331
- # cleanup_memory()
332
-
333
  # ── Decode both video and audio ──
334
  decoded_video = vae_decode_video(
335
  video_state.latent,
@@ -346,7 +297,7 @@ class LTX23DistilledA2VPipeline:
346
  return decoded_video, decoded_audio_output
347
 
348
  # Model repos
349
- LTX_MODEL_REPO = "SulphurAI/Sulphur-2-base"
350
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
351
 
352
  # Download model checkpoints
@@ -367,10 +318,11 @@ weights_dir = Path("weights")
367
  weights_dir.mkdir(exist_ok=True)
368
  checkpoint_path = hf_hub_download(
369
  repo_id=LTX_MODEL_REPO,
370
- filename="sulphur_distil_bf16.safetensors",
371
  local_dir=str(weights_dir),
372
  local_dir_use_symlinks=False,
373
  )
 
374
  spatial_upsampler_path = hf_hub_download(repo_id="Lightricks/LTX-2.3", filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
375
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
376
 
@@ -607,7 +559,6 @@ _orig_video_decoder_factory = ledger.video_decoder
607
  _orig_audio_encoder_factory = ledger.audio_encoder
608
  _orig_audio_decoder_factory = ledger.audio_decoder
609
  _orig_vocoder_factory = ledger.vocoder
610
- _orig_spatial_upsampler_factory = ledger.spatial_upsampler
611
  _orig_text_encoder_factory = ledger.text_encoder
612
  _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
613
 
@@ -618,7 +569,6 @@ _video_decoder = _orig_video_decoder_factory()
618
  _audio_encoder = _orig_audio_encoder_factory()
619
  _audio_decoder = _orig_audio_decoder_factory()
620
  _vocoder = _orig_vocoder_factory()
621
- _spatial_upsampler = _orig_spatial_upsampler_factory()
622
  _text_encoder = _orig_text_encoder_factory()
623
  _embeddings_processor = _orig_gemma_embeddings_factory()
624
 
@@ -630,7 +580,6 @@ ledger.video_decoder = lambda: _video_decoder
630
  ledger.audio_encoder = lambda: _audio_encoder
631
  ledger.audio_decoder = lambda: _audio_decoder
632
  ledger.vocoder = lambda: _vocoder
633
- ledger.spatial_upsampler = lambda: _spatial_upsampler
634
  ledger.text_encoder = lambda: _text_encoder
635
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
636
 
 
217
 
218
  # Stage 1: Generate sigmas using LTX2Scheduler with user-specified steps
219
  empty_latent = torch.empty(VideoLatentShape.from_pixel_shape(
220
+ VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
221
  ).to_torch_shape())
222
  stage_1_sigmas = (
223
  LTX2Scheduler()
 
246
  ),
247
  )
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  # ── Stage 1: Half resolution ──
250
  stage_1_output_shape = VideoPixelShape(
251
  batch=1,
252
  frames=num_frames,
253
+ width=width,
254
+ height=height,
255
  fps=frame_rate,
256
  )
257
  stage_1_conditionings = combined_image_conditionings(
 
281
  torch.cuda.synchronize()
282
  # cleanup_memory()
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  # ── Decode both video and audio ──
285
  decoded_video = vae_decode_video(
286
  video_state.latent,
 
297
  return decoded_video, decoded_audio_output
298
 
299
  # Model repos
300
+ LTX_MODEL_REPO = "TenStrip/LTX2.3-10Eros"
301
  GEMMA_REPO ="Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
302
 
303
  # Download model checkpoints
 
318
  weights_dir.mkdir(exist_ok=True)
319
  checkpoint_path = hf_hub_download(
320
  repo_id=LTX_MODEL_REPO,
321
+ filename="10Eros_v1.2_bf16.safetensors",
322
  local_dir=str(weights_dir),
323
  local_dir_use_symlinks=False,
324
  )
325
+
326
  spatial_upsampler_path = hf_hub_download(repo_id="Lightricks/LTX-2.3", filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
327
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
328
 
 
559
  _orig_audio_encoder_factory = ledger.audio_encoder
560
  _orig_audio_decoder_factory = ledger.audio_decoder
561
  _orig_vocoder_factory = ledger.vocoder
 
562
  _orig_text_encoder_factory = ledger.text_encoder
563
  _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
564
 
 
569
  _audio_encoder = _orig_audio_encoder_factory()
570
  _audio_decoder = _orig_audio_decoder_factory()
571
  _vocoder = _orig_vocoder_factory()
 
572
  _text_encoder = _orig_text_encoder_factory()
573
  _embeddings_processor = _orig_gemma_embeddings_factory()
574
 
 
580
  ledger.audio_encoder = lambda: _audio_encoder
581
  ledger.audio_decoder = lambda: _audio_decoder
582
  ledger.vocoder = lambda: _vocoder
 
583
  ledger.text_encoder = lambda: _text_encoder
584
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
585