linoyts HF Staff commited on
Commit
90b9d86
·
verified ·
1 Parent(s): ade6776

Update packages/ltx-pipelines/src/ltx_pipelines/distilled.py

Browse files
packages/ltx-pipelines/src/ltx_pipelines/distilled.py CHANGED
@@ -64,6 +64,10 @@ class DistilledPipeline:
64
  device=device,
65
  )
66
 
 
 
 
 
67
  @torch.inference_mode()
68
  def __call__(
69
  self,
@@ -76,23 +80,37 @@ class DistilledPipeline:
76
  frame_rate: float,
77
  images: list[tuple[str, int, float]],
78
  tiling_config: TilingConfig | None = None,
 
 
79
  ) -> None:
80
  generator = torch.Generator(device=self.device).manual_seed(seed)
81
  noiser = GaussianNoiser(generator=generator)
82
  stepper = EulerDiffusionStep()
83
  dtype = torch.bfloat16
84
 
85
- text_encoder = self.model_ledger.text_encoder()
86
- context_p = encode_text(text_encoder, prompts=[prompt])[0]
87
- video_context, audio_context = context_p
 
 
88
 
89
- torch.cuda.synchronize()
90
- del text_encoder
91
- utils.cleanup_memory()
 
 
 
 
92
 
93
  # Stage 1: Initial low resolution video generation.
94
- video_encoder = self.model_ledger.video_encoder()
95
- transformer = self.model_ledger.transformer()
 
 
 
 
 
 
96
  stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)
97
 
98
  def denoising_loop(
@@ -168,9 +186,9 @@ class DistilledPipeline:
168
  )
169
 
170
  torch.cuda.synchronize()
171
- del transformer
172
- del video_encoder
173
- utils.cleanup_memory()
174
 
175
  decoded_video = vae_decode_video(video_state, self.model_ledger.video_decoder(), tiling_config)
176
 
@@ -214,4 +232,4 @@ def main() -> None:
214
 
215
 
216
  if __name__ == "__main__":
217
- main()
 
64
  device=device,
65
  )
66
 
67
+ # Cached models to avoid reloading
68
+ self._video_encoder = None
69
+ self._transformer = None
70
+
71
  @torch.inference_mode()
72
  def __call__(
73
  self,
 
80
  frame_rate: float,
81
  images: list[tuple[str, int, float]],
82
  tiling_config: TilingConfig | None = None,
83
+ video_context: torch.Tensor | None = None,
84
+ audio_context: torch.Tensor | None = None,
85
  ) -> None:
86
  generator = torch.Generator(device=self.device).manual_seed(seed)
87
  noiser = GaussianNoiser(generator=generator)
88
  stepper = EulerDiffusionStep()
89
  dtype = torch.bfloat16
90
 
91
+ # Use pre-computed embeddings if provided, otherwise encode text
92
+ if video_context is None or audio_context is None:
93
+ text_encoder = self.model_ledger.text_encoder()
94
+ context_p = encode_text(text_encoder, prompts=[prompt])[0]
95
+ video_context, audio_context = context_p
96
 
97
+ torch.cuda.synchronize()
98
+ del text_encoder
99
+ utils.cleanup_memory()
100
+ else:
101
+ # Move pre-computed embeddings to device if needed
102
+ video_context = video_context.to(self.device)
103
+ audio_context = audio_context.to(self.device)
104
 
105
  # Stage 1: Initial low resolution video generation.
106
+ # Load models only if not already cached
107
+ if self._video_encoder is None:
108
+ self._video_encoder = self.model_ledger.video_encoder()
109
+ video_encoder = self._video_encoder
110
+
111
+ if self._transformer is None:
112
+ self._transformer = self.model_ledger.transformer()
113
+ transformer = self._transformer
114
  stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)
115
 
116
  def denoising_loop(
 
186
  )
187
 
188
  torch.cuda.synchronize()
189
+ # del transformer
190
+ # del video_encoder
191
+ # utils.cleanup_memory()
192
 
193
  decoded_video = vae_decode_video(video_state, self.model_ledger.video_decoder(), tiling_config)
194
 
 
232
 
233
 
234
  if __name__ == "__main__":
235
+ main()