linoyts HF Staff committed on
Commit
e8a7471
·
verified ·
1 Parent(s): 7bf29b2

Update packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py

Browse files
packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py CHANGED
@@ -75,6 +75,11 @@ class TI2VidTwoStagesPipeline:
75
  device=device,
76
  )
77
 
 
 
 
 
 
78
  @torch.inference_mode()
79
  def __call__( # noqa: PLR0913
80
  self,
@@ -90,6 +95,10 @@ class TI2VidTwoStagesPipeline:
90
  cfg_guidance_scale: float,
91
  images: list[tuple[str, int, float]],
92
  tiling_config: TilingConfig | None = None,
 
 
 
 
93
  ) -> None:
94
  generator = torch.Generator(device=self.device).manual_seed(seed)
95
  noiser = GaussianNoiser(generator=generator)
@@ -97,18 +106,33 @@ class TI2VidTwoStagesPipeline:
97
  cfg_guider = CFGGuider(cfg_guidance_scale)
98
  dtype = torch.bfloat16
99
 
100
- text_encoder = self.stage_1_model_ledger.text_encoder()
101
- context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
102
- v_context_p, a_context_p = context_p
103
- v_context_n, a_context_n = context_n
 
 
 
104
 
105
- torch.cuda.synchronize()
106
- del text_encoder
107
- utils.cleanup_memory()
 
 
 
 
 
 
108
 
109
  # Stage 1: Initial low resolution video generation.
110
- video_encoder = self.stage_1_model_ledger.video_encoder()
111
- transformer = self.stage_1_model_ledger.transformer()
 
 
 
 
 
 
112
  sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)
113
 
114
  def first_stage_denoising_loop(
@@ -151,8 +175,8 @@ class TI2VidTwoStagesPipeline:
151
  )
152
 
153
  torch.cuda.synchronize()
154
- del transformer
155
- utils.cleanup_memory()
156
 
157
  # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
158
  upscaled_video_latent = utils.upsample_video(
@@ -162,9 +186,12 @@ class TI2VidTwoStagesPipeline:
162
  )
163
 
164
  torch.cuda.synchronize()
165
- utils.cleanup_memory()
166
 
167
- transformer = self.stage_2_model_ledger.transformer()
 
 
 
168
  distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
169
 
170
  def second_stage_denoising_loop(
@@ -209,9 +236,9 @@ class TI2VidTwoStagesPipeline:
209
  )
210
 
211
  torch.cuda.synchronize()
212
- del transformer
213
- del video_encoder
214
- utils.cleanup_memory()
215
 
216
  decoded_video = vae_decode_video(video_state, self.stage_2_model_ledger.video_decoder(), tiling_config)
217
 
 
75
  device=device,
76
  )
77
 
78
+ # Cached models to avoid reloading
79
+ self._video_encoder = None
80
+ self._stage_1_transformer = None
81
+ self._stage_2_transformer = None
82
+
83
  @torch.inference_mode()
84
  def __call__( # noqa: PLR0913
85
  self,
 
95
  cfg_guidance_scale: float,
96
  images: list[tuple[str, int, float]],
97
  tiling_config: TilingConfig | None = None,
98
+ video_context_positive: torch.Tensor | None = None,
99
+ audio_context_positive: torch.Tensor | None = None,
100
+ video_context_negative: torch.Tensor | None = None,
101
+ audio_context_negative: torch.Tensor | None = None,
102
  ) -> None:
103
  generator = torch.Generator(device=self.device).manual_seed(seed)
104
  noiser = GaussianNoiser(generator=generator)
 
106
  cfg_guider = CFGGuider(cfg_guidance_scale)
107
  dtype = torch.bfloat16
108
 
109
+ # Use pre-computed embeddings if provided, otherwise encode text
110
+ if (video_context_positive is None or audio_context_positive is None or
111
+ video_context_negative is None or audio_context_negative is None):
112
+ text_encoder = self.stage_1_model_ledger.text_encoder()
113
+ context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
114
+ v_context_p, a_context_p = context_p
115
+ v_context_n, a_context_n = context_n
116
 
117
+ torch.cuda.synchronize()
118
+ del text_encoder
119
+ utils.cleanup_memory()
120
+ else:
121
+ # Move pre-computed embeddings to device if needed
122
+ v_context_p = video_context_positive.to(self.device)
123
+ a_context_p = audio_context_positive.to(self.device)
124
+ v_context_n = video_context_negative.to(self.device)
125
+ a_context_n = audio_context_negative.to(self.device)
126
 
127
  # Stage 1: Initial low resolution video generation.
128
+ # Load models only if not already cached
129
+ if self._video_encoder is None:
130
+ self._video_encoder = self.stage_1_model_ledger.video_encoder()
131
+ video_encoder = self._video_encoder
132
+
133
+ if self._stage_1_transformer is None:
134
+ self._stage_1_transformer = self.stage_1_model_ledger.transformer()
135
+ transformer = self._stage_1_transformer
136
  sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)
137
 
138
  def first_stage_denoising_loop(
 
175
  )
176
 
177
  torch.cuda.synchronize()
178
+ # del transformer
179
+ # utils.cleanup_memory()
180
 
181
  # Stage 2: Upsample and refine the video at higher resolution with distilled LORA.
182
  upscaled_video_latent = utils.upsample_video(
 
186
  )
187
 
188
  torch.cuda.synchronize()
189
+ # utils.cleanup_memory()
190
 
191
+ # Load stage 2 transformer only if not already cached
192
+ if self._stage_2_transformer is None:
193
+ self._stage_2_transformer = self.stage_2_model_ledger.transformer()
194
+ transformer = self._stage_2_transformer
195
  distilled_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
196
 
197
  def second_stage_denoising_loop(
 
236
  )
237
 
238
  torch.cuda.synchronize()
239
+ # del transformer
240
+ # del video_encoder
241
+ # utils.cleanup_memory()
242
 
243
  decoded_video = vae_decode_video(video_state, self.stage_2_model_ledger.video_decoder(), tiling_config)
244