Eueuiaa commited on
Commit
dfa1859
·
verified ·
1 Parent(s): 1054840

Update api/ltx_server_refactored.py

Browse files
Files changed (1) hide show
  1. api/ltx_server_refactored.py +29 -10
api/ltx_server_refactored.py CHANGED
@@ -233,10 +233,18 @@ class VideoService:
233
  tensor = self._prepare_conditioning_tensor(media, height, width, padding_values) if isinstance(media, str) else media.to(self.device, dtype=self.runtime_autocast_dtype)
234
  safe_frame = max(0, min(int(frame), num_frames - 1))
235
  conditioning_items.append(ConditioningItem(tensor, safe_frame, float(weight)))
236
-
237
-
238
  return conditioning_items
239
 
 
 
 
 
 
 
 
 
 
 
240
  def generate_low(self, prompt, negative_prompt, height, width, duration, guidance_scale, seed, conditioning_items=None):
241
  used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
242
  seed_everething(used_seed)
@@ -409,8 +417,10 @@ class VideoService:
409
  frames_per_chunk_last = max(9, frames_per_chunk_last)
410
 
411
  poda_latents_num = overlap_frames // self.pipeline.video_scale_factor if self.pipeline.video_scale_factor > 0 else 0
412
-
 
413
  latentes_chunk_video = []
 
414
  lista_patch_latentes_chunk = []
415
  condition_item_latent_overlap = None
416
  temp_dir = tempfile.mkdtemp(prefix="ltxv_narrative_"); self._register_tmp_dir(temp_dir)
@@ -443,10 +453,22 @@ class VideoService:
443
 
444
  frames_per_chunk = ((frames_per_chunk - 1)//8)*8 + 1
445
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  latentes_bruto_r = self._generate_single_chunk_low(
447
  prompt=chunk_prompt, negative_prompt=negative_prompt, height=height, width=width,
448
  num_frames=frames_per_chunk, guidance_scale=guidance_scale, seed=used_seed + i,
449
- itens_conditions_itens=None,
450
  ltx_configs_override=ltx_configs_override
451
  )
452
 
@@ -458,12 +480,9 @@ class VideoService:
458
  #final_latents = torch.cat(lista_tensores, dim=2).to(self.device)
459
 
460
 
461
- if i== 0:
462
- initial_conditions = None
463
- else:
464
- initial_conditions = initial_conditions
465
-
466
-
467
 
468
 
469
  #poda inicio overlap
 
233
  tensor = self._prepare_conditioning_tensor(media, height, width, padding_values) if isinstance(media, str) else media.to(self.device, dtype=self.runtime_autocast_dtype)
234
  safe_frame = max(0, min(int(frame), num_frames - 1))
235
  conditioning_items.append(ConditioningItem(tensor, safe_frame, float(weight)))
 
 
236
  return conditioning_items
237
 
238
+ def prepare_condition_items_latent(self, items_list: List, height: int, width: int, num_frames: int):
239
+ if not items_list: return []
240
+ conditioning_items = []
241
+ for tensor_patch, frame, weight in items_list:
242
+ tensor = torch.load(tensor_patch).to(self.device)
243
+ safe_frame = max(0, min(int(frame), num_frames - 1))
244
+ conditioning_items.append(ConditioningItem(tensor, safe_frame, float(weight)))
245
+ return conditioning_items
246
+
247
+
248
  def generate_low(self, prompt, negative_prompt, height, width, duration, guidance_scale, seed, conditioning_items=None):
249
  used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
250
  seed_everething(used_seed)
 
417
  frames_per_chunk_last = max(9, frames_per_chunk_last)
418
 
419
  poda_latents_num = overlap_frames // self.pipeline.video_scale_factor if self.pipeline.video_scale_factor > 0 else 0
420
+
421
+ initial_conditions= []
422
  latentes_chunk_video = []
423
+ overlap_condition = []
424
  lista_patch_latentes_chunk = []
425
  condition_item_latent_overlap = None
426
  temp_dir = tempfile.mkdtemp(prefix="ltxv_narrative_"); self._register_tmp_dir(temp_dir)
 
453
 
454
  frames_per_chunk = ((frames_per_chunk - 1)//8)*8 + 1
455
 
456
+
457
+ if i== 0:
458
+ initial_conditions = initial_conditions
459
+ else:
460
+ initial_conditions = None
461
+
462
+ if overlap_latents!=None:
463
+ items_list = [[overlap_latents, 0, 1.0]]
464
+ overlap_condition = prepare_condition_items_latent(items_list)
465
+
466
+ itens_conditions_itens = latentes_chunk_video + overlap_condition
467
+
468
  latentes_bruto_r = self._generate_single_chunk_low(
469
  prompt=chunk_prompt, negative_prompt=negative_prompt, height=height, width=width,
470
  num_frames=frames_per_chunk, guidance_scale=guidance_scale, seed=used_seed + i,
471
+ itens_conditions_itens=itens_conditions_itens,
472
  ltx_configs_override=ltx_configs_override
473
  )
474
 
 
480
  #final_latents = torch.cat(lista_tensores, dim=2).to(self.device)
481
 
482
 
483
+
484
+
485
+
 
 
 
486
 
487
 
488
  #poda inicio overlap