linoyts HF Staff commited on
Commit
ecfa616
·
verified ·
1 Parent(s): 5efa740

Update packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py

Browse files
packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py CHANGED
@@ -22,6 +22,7 @@ from ltx_pipelines.pipeline_utils import (
22
  denoise_audio_video,
23
  encode_text,
24
  euler_denoising_loop,
 
25
  guider_denoising_func,
26
  simple_denoising_func,
27
  )
@@ -90,6 +91,10 @@ class TI2VidTwoStagesPipeline:
90
  cfg_guidance_scale: float,
91
  images: list[tuple[str, int, float]],
92
  tiling_config: TilingConfig | None = None,
 
 
 
 
93
  ) -> None:
94
  generator = torch.Generator(device=self.device).manual_seed(seed)
95
  noiser = GaussianNoiser(generator=generator)
@@ -97,14 +102,23 @@ class TI2VidTwoStagesPipeline:
97
  cfg_guider = CFGGuider(cfg_guidance_scale)
98
  dtype = torch.bfloat16
99
 
100
- text_encoder = self.stage_1_model_ledger.text_encoder()
101
- context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
102
- v_context_p, a_context_p = context_p
103
- v_context_n, a_context_n = context_n
 
 
 
104
 
105
- torch.cuda.synchronize()
106
- del text_encoder
107
- utils.cleanup_memory()
 
 
 
 
 
 
108
 
109
  # Stage 1: Initial low resolution video generation.
110
  video_encoder = self.stage_1_model_ledger.video_encoder()
@@ -170,7 +184,18 @@ class TI2VidTwoStagesPipeline:
170
  def second_stage_denoising_loop(
171
  sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
172
  ) -> tuple[LatentState, LatentState]:
173
- return euler_denoising_loop(
 
 
 
 
 
 
 
 
 
 
 
174
  sigmas=sigmas,
175
  video_state=video_state,
176
  audio_state=audio_state,
@@ -180,6 +205,7 @@ class TI2VidTwoStagesPipeline:
180
  audio_context=a_context_p,
181
  transformer=transformer, # noqa: F821
182
  ),
 
183
  )
184
 
185
  stage_2_output_shape = VideoPixelShape(
 
22
  denoise_audio_video,
23
  encode_text,
24
  euler_denoising_loop,
25
+ gradient_estimating_euler_denoising_loop,
26
  guider_denoising_func,
27
  simple_denoising_func,
28
  )
 
91
  cfg_guidance_scale: float,
92
  images: list[tuple[str, int, float]],
93
  tiling_config: TilingConfig | None = None,
94
+ video_context_positive: torch.Tensor | None = None,
95
+ audio_context_positive: torch.Tensor | None = None,
96
+ video_context_negative: torch.Tensor | None = None,
97
+ audio_context_negative: torch.Tensor | None = None,
98
  ) -> None:
99
  generator = torch.Generator(device=self.device).manual_seed(seed)
100
  noiser = GaussianNoiser(generator=generator)
 
102
  cfg_guider = CFGGuider(cfg_guidance_scale)
103
  dtype = torch.bfloat16
104
 
105
+ # Use pre-computed embeddings if provided, otherwise encode text
106
+ if (video_context_positive is None or audio_context_positive is None or
107
+ video_context_negative is None or audio_context_negative is None):
108
+ text_encoder = self.stage_1_model_ledger.text_encoder()
109
+ context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
110
+ v_context_p, a_context_p = context_p
111
+ v_context_n, a_context_n = context_n
112
 
113
+ torch.cuda.synchronize()
114
+ del text_encoder
115
+ utils.cleanup_memory()
116
+ else:
117
+ # Move pre-computed embeddings to device if needed
118
+ v_context_p = video_context_positive.to(self.device)
119
+ a_context_p = audio_context_positive.to(self.device)
120
+ v_context_n = video_context_negative.to(self.device)
121
+ a_context_n = audio_context_negative.to(self.device)
122
 
123
  # Stage 1: Initial low resolution video generation.
124
  video_encoder = self.stage_1_model_ledger.video_encoder()
 
184
  def second_stage_denoising_loop(
185
  sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
186
  ) -> tuple[LatentState, LatentState]:
187
+ # return euler_denoising_loop(
188
+ # sigmas=sigmas,
189
+ # video_state=video_state,
190
+ # audio_state=audio_state,
191
+ # stepper=stepper,
192
+ # denoise_fn=simple_denoising_func(
193
+ # video_context=v_context_p,
194
+ # audio_context=a_context_p,
195
+ # transformer=transformer, # noqa: F821
196
+ # ),
197
+ # )
198
+ return gradient_estimating_euler_denoising_loop(
199
  sigmas=sigmas,
200
  video_state=video_state,
201
  audio_state=audio_state,
 
205
  audio_context=a_context_p,
206
  transformer=transformer, # noqa: F821
207
  ),
208
+ ge_gamma=2.0, # Gradient estimation coefficient
209
  )
210
 
211
  stage_2_output_shape = VideoPixelShape(