eeuuia committed on
Commit
edd6b83
·
verified ·
1 Parent(s): 9e88bfd

Update api/ltx/ltx_aduc_pipeline.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_aduc_pipeline.py +27 -28
api/ltx/ltx_aduc_pipeline.py CHANGED
@@ -119,7 +119,7 @@ class LtxAducPipeline:
119
  num_chunks = len(prompt_list)
120
  total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
121
  frames_per_chunk = max(FRAMES_ALIGNMENT, (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT)
122
- overlap_frames = 9 if is_narrative else 0
123
 
124
  initial_conditions = []
125
  if initial_media_items:
@@ -133,6 +133,8 @@ class LtxAducPipeline:
133
 
134
  temp_latent_paths = []
135
  overlap_condition_item: Optional[LatentConditioningItem] = None
 
 
136
 
137
  try:
138
  for i, chunk_prompt in enumerate(prompt_list):
@@ -146,8 +148,9 @@ class LtxAducPipeline:
146
  if overlap_condition_item: current_conditions.append(overlap_condition_item)
147
 
148
  chunk_latents = self._generate_single_chunk_low(
149
- prompt=chunk_prompt, num_frames=current_frames, seed=used_seed + i,
150
- conditioning_items=current_conditions, **kwargs
 
151
  )
152
  if chunk_latents is None: raise RuntimeError(f"Failed to generate latents for scene {i+1}.")
153
 
@@ -158,6 +161,7 @@ class LtxAducPipeline:
158
  media_frame_number=0,
159
  conditioning_strength=1.0
160
  )
 
161
 
162
  if i > 0: chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]
163
 
@@ -183,6 +187,7 @@ class LtxAducPipeline:
183
  # --- UNIDADES DE TRABALHO E HELPERS INTERNOS ---
184
  # ==========================================================================
185
 
 
186
  def _log_conditioning_items(self, items: List[LatentConditioningItem]):
187
  """
188
  Logs detailed information about a list of ConditioningItem objects.
@@ -206,18 +211,20 @@ class LtxAducPipeline:
206
  f"Strength = {item.conditioning_strength:.2f}"
207
  )
208
  else:
209
- tt = str(itemvalue)
210
  log_str.append(f" -> Item [{i}]: Não contém um tensor válido.")
211
- log_str.append(f" {tt[:70]}")
212
 
213
- log_str.append("="*40 + "\n")
214
 
215
  # Usa o logger de debug para imprimir a mensagem completa
216
  logging.info("\n".join(log_str))
217
 
218
 
219
  @log_function_io
220
- def _generate_single_chunk_low(self, **kwargs) -> Optional[torch.Tensor]:
 
 
 
 
221
  """[WORKER] Calls the pipeline to generate a single chunk of latents."""
222
  height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
223
  downscale_factor = self.config.get("downscale_factor", 0.6666666)
@@ -225,26 +232,15 @@ class LtxAducPipeline:
225
  downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
226
  downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
227
 
228
- # 1. Começa com a configuração padrão
229
- first_pass_config = self.config.get("first_pass", {}).copy()
230
-
231
- # 2. Aplica os overrides da UI, se existirem
232
- if kwargs.get("ltx_configs_override"):
233
- self._apply_ui_overrides(first_pass_config, kwargs.get("ltx_configs_override"))
234
-
235
-
236
- # 3. Monta o dicionário de argumentos SEM conditioning_items primeiro
237
- pipeline_kwargs = {
238
- "num_inference_steps": first_pass_config.get("num_inference_steps"),
239
- "skip_final_inference_steps": first_pass_config.get("skip_final_inference_steps"),
240
  "cfg_star_rescale": "true",
241
- "prompt": kwargs['prompt'],
242
  "negative_prompt": kwargs['negative_prompt'],
243
  "height": downscaled_height,
244
  "width": downscaled_width,
245
- "num_frames": kwargs['num_frames'],
246
  "frame_rate": int(DEFAULT_FPS),
247
- "generator": torch.Generator(device=self.main_device).manual_seed(kwargs['seed']),
248
  "output_type": "latent",
249
  "media_items": None,
250
  "decode_timestep": self.config["decode_timestep"],
@@ -257,14 +253,17 @@ class LtxAducPipeline:
257
  "offload_to_cpu": False,
258
  "enhance_prompt": False,
259
  }
260
-
261
- # Loga os conditioning_items separadamente com a nossa função helper
262
- conditioning_items_list = kwargs.get('conditioning_items')
263
- self._log_conditioning_items(conditioning_items_list)
264
- pipeline_kwargs['conditioning_items'] = conditioning_items_list
 
 
 
265
 
266
  with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
267
- latents_raw = self.pipeline(**pipeline_kwargs).images
268
 
269
  return latents_raw.to(self.main_device)
270
 
 
119
  num_chunks = len(prompt_list)
120
  total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
121
  frames_per_chunk = max(FRAMES_ALIGNMENT, (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT)
122
+ overlap_frames = 8 if is_narrative else 0
123
 
124
  initial_conditions = []
125
  if initial_media_items:
 
133
 
134
  temp_latent_paths = []
135
  overlap_condition_item: Optional[LatentConditioningItem] = None
136
+
137
+ current_conditions = initial_conditions
138
 
139
  try:
140
  for i, chunk_prompt in enumerate(prompt_list):
 
148
  if overlap_condition_item: current_conditions.append(overlap_condition_item)
149
 
150
  chunk_latents = self._generate_single_chunk_low(
151
+ prompt_x=chunk_prompt, num_frames_x=current_frames, seed_x=used_seed,
152
+ conditioning_items_x=current_conditions,
153
+ **kwargs
154
  )
155
  if chunk_latents is None: raise RuntimeError(f"Failed to generate latents for scene {i+1}.")
156
 
 
161
  media_frame_number=0,
162
  conditioning_strength=1.0
163
  )
164
+ current_conditions=overlap_condition_item
165
 
166
  if i > 0: chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]
167
 
 
187
  # --- UNIDADES DE TRABALHO E HELPERS INTERNOS ---
188
  # ==========================================================================
189
 
190
+ @log_function_io
191
  def _log_conditioning_items(self, items: List[LatentConditioningItem]):
192
  """
193
  Logs detailed information about a list of ConditioningItem objects.
 
211
  f"Strength = {item.conditioning_strength:.2f}"
212
  )
213
  else:
 
214
  log_str.append(f" -> Item [{i}]: Não contém um tensor válido.")
 
215
 
216
+ log_str.append("="*30 + "\n")
217
 
218
  # Usa o logger de debug para imprimir a mensagem completa
219
  logging.info("\n".join(log_str))
220
 
221
 
222
  @log_function_io
223
+ def _generate_single_chunk_low(
224
+ prompt_x:str, num_frames_x:int, seed_x:int,
225
+ conditioning_items_x:LatentConditioningItem,
226
+ **kwargs
227
+ ) -> Optional[torch.Tensor]:
228
  """[WORKER] Calls the pipeline to generate a single chunk of latents."""
229
  height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
230
  downscale_factor = self.config.get("downscale_factor", 0.6666666)
 
232
  downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
233
  downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
234
 
235
+ call_kwargs = {
 
 
 
 
 
 
 
 
 
 
 
236
  "cfg_star_rescale": "true",
237
+ "prompt": prompt_x,
238
  "negative_prompt": kwargs['negative_prompt'],
239
  "height": downscaled_height,
240
  "width": downscaled_width,
241
+ "num_frames": num_frames_x,
242
  "frame_rate": int(DEFAULT_FPS),
243
+ "generator": torch.Generator(device=self.main_device).manual_seed(seed_x),
244
  "output_type": "latent",
245
  "media_items": None,
246
  "decode_timestep": self.config["decode_timestep"],
 
253
  "offload_to_cpu": False,
254
  "enhance_prompt": False,
255
  }
256
+
257
+ call_kwargs.pop("num_inference_steps", None)
258
+ call_kwargs.pop("second_pass", None)
259
+ first_pass_config = self.config.get("first_pass", {}).copy()
260
+ call_kwargs.update(first_pass_config)
261
+ ltx_configs_override = kwargs.get("ltx_configs_override", {}).copy()
262
+ call_kwargs.update(ltx_configs_override)
263
+ call_kwargs['conditioning_items'] = conditioning_items_x
264
 
265
  with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
266
+ latents_raw = self.pipeline(**call_kwargs).images
267
 
268
  return latents_raw.to(self.main_device)
269