caarleexx committed on
Commit
9d2962f
·
verified ·
1 Parent(s): f2a0118

Update api/ltx_server_refactored.py

Browse files
Files changed (1) hide show
  1. api/ltx_server_refactored.py +113 -2
api/ltx_server_refactored.py CHANGED
@@ -307,7 +307,6 @@ class VideoService:
307
 
308
 
309
 
310
- # ADICIONE A FUNÇÃO ABAIXO
311
  @torch.no_grad()
312
  def _image_to_latents(self, image_input: Union[str, Image.Image], height: int, width: int) -> torch.Tensor:
313
  """
@@ -428,7 +427,7 @@ class VideoService:
428
  tensor_path = self._save_latents_to_disk(latents, "latents_low_res", used_seed)
429
 
430
  final_video_path = self._save_video_from_tensor(pixel_tensor, f"final_video_{seed}", seed, temp_dir, fps=DEFAULT_FPS)
431
- return final_video_path
432
 
433
  # --- Limpeza ---
434
  self._finalize()
@@ -437,6 +436,118 @@ class VideoService:
437
  return final_video_path, tensor_path, used_seed
438
 
439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  # --------------------------------------------------------------------------
441
  # --- Métodos Internos e Auxiliares ---
442
  # --------------------------------------------------------------------------
 
307
 
308
 
309
 
 
310
  @torch.no_grad()
311
  def _image_to_latents(self, image_input: Union[str, Image.Image], height: int, width: int) -> torch.Tensor:
312
  """
 
427
  tensor_path = self._save_latents_to_disk(latents, "latents_low_res", used_seed)
428
 
429
  final_video_path = self._save_video_from_tensor(pixel_tensor, f"final_video_{seed}", seed, temp_dir, fps=DEFAULT_FPS)
430
+ # A linha "return final_video_path" foi removida daqui!
431
 
432
  # --- Limpeza ---
433
  self._finalize()
 
436
  return final_video_path, tensor_path, used_seed
437
 
438
 
439
+ # --------------------------------------------------------------------------
440
+ # --- Métodos Públicos (API do Serviço) ---
441
+ # --------------------------------------------------------------------------
442
+
443
def generate_upscale_denoise(
    self,
    latents_path: str,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float,
    seed: int,
) -> Tuple[str, str]:
    """Stage 2: upscale saved low-resolution latents and refine them with the LTX pipeline.

    Output height, width and frame count are not parameters; they are derived
    from the shape of the upscaled latent tensor.

    Args:
        latents_path: Path of a file readable by ``torch.load`` holding the
            low-resolution latents.
        prompt: Positive text prompt forwarded to the pipeline.
        negative_prompt: Negative text prompt forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline (cast to float).
        seed: Seed for the ``torch.Generator`` and for naming the output files.

    Returns:
        ``(video_path, tensor_path)``: the refined video file and the saved
        refined latents.

    Raises:
        FileNotFoundError: If ``latents_path`` does not exist.
    """
    print("[INFO] Iniciando ETAPA 2: Upscale e Refinamento de Textura (LTX)...")

    # --- 1. Load the input latents ---
    if not Path(latents_path).exists():
        raise FileNotFoundError(f"Latentes não encontrados no caminho: {latents_path}")

    # Load onto the CPU first so a file saved from another (possibly absent)
    # device still deserializes, then move to the service device.
    # NOTE(review): torch.load unpickles the file; only pass trusted paths.
    latents_low_res = torch.load(latents_path, map_location="cpu").to(
        self.device, dtype=self._get_precision_dtype()
    )
    log_tensor_info(latents_low_res, "Latentes Carregados (Baixa Resolução)")

    # self.device is a string such as "cuda" or "cuda:0". Autocast must be
    # enabled for any CUDA device, not only for the bare "cuda" string.
    device_type = self.device.split(":")[0]
    use_autocast = device_type == "cuda"

    # --- 2. Spatial upsample and filtering ---
    with torch.autocast(device_type=device_type, dtype=self.runtime_autocast_dtype, enabled=use_autocast):
        upsampled_latents = self._upsample_and_filter_latents(latents_low_res)
    log_tensor_info(upsampled_latents, "Latentes Upscaled (Antes do Denoise)")

    # --- 3. Denoise (second pipeline pass) ---
    # The tensor is unpacked as 5-D; the last two axes are used as latent height/width.
    _, _, latent_num_frames, latent_height, latent_width = upsampled_latents.shape
    target_height = latent_height * self.pipeline.vae_scale_factor
    target_width = latent_width * self.pipeline.vae_scale_factor

    # Converts latent frames back to pixel frames.
    # NOTE(review): assumes a temporal compression factor of 8 (n*8+1) - confirm against the VAE.
    actual_num_frames = (latent_num_frames - 1) * 8 + 1

    print(f" - Resolução de Saída Estimada: {target_height}x{target_width}")
    print(f" - Frames Estimados: {actual_num_frames}")

    temp_dir = tempfile.mkdtemp(prefix="ltxv_high_")
    self._register_tmp_dir(temp_dir)

    second_pass_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "height": target_height,
        "width": target_width,
        "num_frames": actual_num_frames,
        "frame_rate": int(DEFAULT_FPS),
        "generator": torch.Generator(device=self.device).manual_seed(seed),
        "output_type": "latent",
        "vae_per_channel_normalize": True,
        "is_video": True,
        "latents": upsampled_latents,  # upscaled latents are the starting point
        "guidance_scale": float(guidance_scale),
        # Unpacked last, so "second_pass" config entries override the keys above.
        # `or {}` tolerates a config where the key exists but is None.
        **(self.config.get("second_pass") or {}),
    }

    print(" - Enviando para a pipeline LTX (Refinamento)...")
    with torch.autocast(device_type=device_type, dtype=self.runtime_autocast_dtype, enabled=use_autocast):
        refined_latents = self.pipeline(**second_pass_kwargs).images
    log_tensor_info(refined_latents, "Latentes Refinados (Saída do Denoise)")

    # --- 4. Decode and write outputs ---
    pixel_tensor = vae_manager_singleton.decode(
        refined_latents,
        decode_timestep=float(self.config.get("decode_timestep", 0.00)),
    )

    tensor_path = self._save_latents_to_disk(refined_latents, "latents_refined", seed)
    video_path = self._save_video_from_tensor(
        pixel_tensor, f"refined_video_{seed}", seed, temp_dir, fps=DEFAULT_FPS
    )

    self._finalize()

    print("[SUCCESS] ETAPA 2 Concluída.")
    return video_path, tensor_path
521
+
522
def move_to_cpu(self):
    """Move the pipeline and, if present, the latent upsampler to the CPU.

    Afterwards runs garbage collection and, when CUDA is available, empties
    the CUDA cache and collects IPC memory.

    NOTE(review): ``self.device`` is not updated here - confirm whether
    callers rely on it still naming the previous device.
    """
    print("[LTX/SWAP] Movendo modelos LTX para a CPU...")
    self.pipeline.to("cpu")
    # Explicit None check: truthiness would skip a module that defines
    # __len__ and reports zero length.
    if self.latent_upsampler is not None:
        self.latent_upsampler.to("cpu")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    print("[LTX/SWAP] Modelos LTX na CPU.")
533
+
534
def move_to_device(self, device: Union[str, torch.device]):
    """Move the pipeline and, if present, the latent upsampler to ``device``.

    When ``str(device)`` is ``"cpu"`` the call is delegated to
    ``move_to_cpu``. Otherwise the models are moved, ``self.device`` is set
    to ``str(device)``, garbage collection runs and, when CUDA is available,
    the CUDA cache is emptied.

    Args:
        device: Target device; a ``torch.device`` or its string form.

    NOTE(review): the CPU branch does not assign ``self.device`` - confirm
    whether that asymmetry is intended.
    """
    if str(device) == "cpu":
        return self.move_to_cpu()

    print(f"[LTX/SWAP] Movendo modelos LTX para {device}...")
    self.pipeline.to(device)
    # Explicit None check: truthiness would skip a module that defines
    # __len__ and reports zero length.
    if self.latent_upsampler is not None:
        self.latent_upsampler.to(device)
    # Stored as a string; other methods in this class call self.device.split(':').
    self.device = str(device)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    print(f"[LTX/SWAP] Modelos LTX em {device}.")
549
+
550
+
551
  # --------------------------------------------------------------------------
552
  # --- Métodos Internos e Auxiliares ---
553
  # --------------------------------------------------------------------------