eeuuia committed · Commit 1cacf10 · verified · Parent(s): 460fa35

Update api/ltx/vae_aduc_pipeline.py

Files changed (1):
  1. api/ltx/vae_aduc_pipeline.py (+156 −145)
api/ltx/vae_aduc_pipeline.py CHANGED
@@ -1,164 +1,175 @@
  # FILE: api/ltx/vae_aduc_pipeline.py
- # DESCRIPTION: A dedicated, "hot" VAE service specialist.
- # It loads the VAE model onto a dedicated GPU (managed by GPUManager)
- # and keeps it in memory to handle all encoding and decoding requests
- # with minimal latency, using the instance pre-loaded by LTXAducManager.

- import os
- import sys
- import time
  import logging
- from pathlib import Path
- from typing import List, Union, Tuple
-
  import torch
- import numpy as np
  from PIL import Image

- # Import the GPU manager and the main LTX manager
- from managers.gpu_manager import gpu_manager
- from api.ltx.ltx_aduc_manager import LatentConditioningItem, ltx_aduc_manager

- # --- Architecture and LTX imports ---
- try:
-     # Add the path to the LTX libraries
-     LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
-     if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
-         sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))
-
  from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
  from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
- except ImportError as e:
-     raise ImportError(f"A crucial import failed for VaeLtxAducPipeline. Check dependencies. Error: {e}")
-
-
- class VaeLtxAducPipeline:
-     _instance = None
-
-     def __new__(cls, *args, **kwargs):
-         if cls._instance is None:
-             cls._instance = super().__new__(cls)
-             cls._instance._initialized = False
-         return cls._instance
-
      def __init__(self):
-         if self._initialized: return
-
-         logging.info("⚙️ Initializing VaeLtxAducPipeline Singleton...")
-         t0 = time.time()
-
-         # 1. Get the dedicated VAE device from the central manager
-         self.device = gpu_manager.get_ltx_vae_device()
-
-         # 2. Get a reference to the VAE model already loaded and placed by LTXAducManager
-         try:
-             # This is the crucial step: we reuse the already-existing pipeline.
-             self.vae = ltx_aduc_manager.get_pipeline().vae
-         except Exception as e:
-             logging.critical(f"Failed to get VAE from LTXAducManager. Is it initialized first? Error: {e}", exc_info=True)
-             raise
-
-         # 3. Confirmation: make sure the VAE is on the correct device.
-         # LTXAducManager should already have done this, but this is a safety check.
-         if self.vae.device != self.device:
-             logging.warning(f"VAE device mismatch! Expected {self.device} but found {self.vae.device}. Forcing move.")
-             self.vae.to(self.device)
-
-         self.vae.eval()
-         self.dtype = self.vae.dtype
-
-         self._initialized = True
-         logging.info(f"✅ VaeLtxAducPipeline ready. VAE model is 'hot' on {self.device} with dtype {self.dtype}. Startup time: {time.time() - t0:.2f}s")
-
-     def _cleanup_gpu(self):
-         """Frees VRAM on the VAE's GPU."""
-         if torch.cuda.is_available():
-             with torch.cuda.device(self.device):
-                 torch.cuda.empty_cache()
-
-     def _preprocess_input(self, item: Union[Image.Image, torch.Tensor], target_resolution: Tuple[int, int]) -> torch.Tensor:
-         """Prepares a PIL image or a tensor in the pixel format the VAE expects."""
-         if isinstance(item, Image.Image):
-             from PIL import ImageOps
-             img = item.convert("RGB")
-             # Resize, preserving the aspect ratio and cropping the excess
-             processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
-             image_np = np.array(processed_img).astype(np.float32) / 255.0
-             tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
-         elif isinstance(item, torch.Tensor):
-             # If it is already a tensor, just make sure it is in CHW format
-             if item.ndim == 4 and item.shape[0] == 1:  # Drop the batch dimension if present
-                 tensor = item.squeeze(0)
-             elif item.ndim == 3:
-                 tensor = item
-             else:
-                 raise ValueError(f"Input tensor must have 3 or 4 dimensions (CHW or BCHW), but got {item.ndim}")
-         else:
-             raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")
-
-         # Convert to 5D (B, C, F, H, W) and normalize to [-1, 1]
-         tensor_5d = tensor.unsqueeze(0).unsqueeze(2)  # Add B=1 and F=1
-         return (tensor_5d * 2.0) - 1.0

-     @torch.no_grad()
-     def generate_conditioning_items(
          self,
-         media_items: List[Union[Image.Image, torch.Tensor]],
-         target_frames: List[int],
-         strengths: List[float],
-         target_resolution: Tuple[int, int]
-     ) -> List[LatentConditioningItem]:
          """
-         [MAIN FUNCTION]
-         Converts a list of images (PIL or pixel tensors) into a list of
-         LatentConditioningItem objects, ready for use by the corrected LTX pipeline.
          """
          t0 = time.time()
-         logging.info(f"Generating {len(media_items)} latent conditioning items on device {self.device}...")
-
-         if not (len(media_items) == len(target_frames) == len(strengths)):
-             raise ValueError("The media_items, target_frames, and strengths lists must all have the same length.")
-
-         conditioning_items = []
-         try:
-             for item, frame, strength in zip(media_items, target_frames, strengths):
-                 # 1. Prepare the image/tensor in the correct pixel format
-                 pixel_tensor = self._preprocess_input(item, target_resolution)
-
-                 # 2. Move the pixel tensor to the VAE's GPU and encode it to a latent
-                 pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
-                 latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)
-
-                 # 3. Create the LatentConditioningItem with the latent (moved to the CPU to avoid holding VRAM)
-                 conditioning_items.append(LatentConditioningItem(latents.cpu(), frame, strength))
-
-             logging.info(f"Generated {len(conditioning_items)} items in {time.time() - t0:.2f}s.")
              return conditioning_items
-         finally:
-             self._cleanup_gpu()
-
-     @torch.no_grad()
-     def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
-         """Decodes a latent tensor into a pixel tensor, returned on the CPU."""
-         t0 = time.time()
-         try:
-             latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
-             num_items_in_batch = latent_tensor_gpu.shape[0]
-             timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=self.dtype)

-             pixels = vae_decode(
-                 latent_tensor_gpu, self.vae, is_video=True,
-                 timestep=timestep_tensor, vae_per_channel_normalize=True
-             )
-             logging.info(f"Decoded latents with shape {latent_tensor.shape} in {time.time() - t0:.2f}s.")
-             return pixels.cpu()  # Return on the CPU to free VRAM on the VAE's GPU
-         finally:
-             self._cleanup_gpu()
-
- # --- Singleton instance ---
- # Initialization happens the first time the module is imported.
  try:
-     vae_ltx_aduc_pipeline = VaeLtxAducPipeline()
  except Exception as e:
-     logging.critical("CRITICAL: Failed to initialize VaeLtxAducPipeline singleton.", exc_info=True)
-     vae_ltx_aduc_pipeline = None

  # FILE: api/ltx/vae_aduc_pipeline.py
+ # DESCRIPTION: A high-level client for submitting VAE-related jobs to the LTXAducManager pool.
+ # It handles encoding media to latents, decoding latents to pixels, and creating ConditioningItems.

  import logging
+ import time

  import torch
+ import torchvision.transforms.functional as TVF
  from PIL import Image
+ from typing import List, Optional, Union, Tuple, Literal
+ from dataclasses import dataclass
+ import sys
+ from pathlib import Path

+ from api.ltx.ltx_aduc_manager import ltx_aduc_manager

+ DEPS_DIR = Path("/data")
+ LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
+ repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
+ if repo_path not in sys.path:
+     sys.path.insert(0, repo_path)
+ print(f"[DEBUG] Repo added to sys.path: {repo_path}")

  from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
  from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
+ import ltx_video.pipelines.crf_compressor as crf_compressor
+
+ # ==============================================================================
+ # --- STRUCTURE DEFINITIONS AND HELPERS (imported or moved here) ---
+ # ==============================================================================
+
+ @dataclass
+ class LatentConditioningItem:
+     """
+     Data structure for passing conditioning latents between services.
+     The latent tensor is kept on the CPU to save VRAM.
+     """
+     latent_tensor: torch.Tensor
+     media_frame_number: int
+     conditioning_strength: float
+
+ def load_image_to_tensor_with_resize_and_crop(
+     image_input: Union[str, Image.Image],
+     target_height: int,
+     target_width: int,
+ ) -> torch.Tensor:
+     """
+     Loads and processes an image into a 5D pixel tensor, normalized to [-1, 1],
+     ready to be sent to the VAE.
+     """
+     if isinstance(image_input, str):
+         image = Image.open(image_input).convert("RGB")
+     elif isinstance(image_input, Image.Image):
+         image = image_input.convert("RGB")
+     else:
+         raise ValueError("image_input must be a file path or a PIL Image object")
+
+     input_width, input_height = image.size
+     aspect_ratio_target = target_width / target_height
+     aspect_ratio_frame = input_width / input_height
+
+     if aspect_ratio_frame > aspect_ratio_target:
+         new_width, new_height = int(input_height * aspect_ratio_target), input_height
+         x_start, y_start = (input_width - new_width) // 2, 0
+     else:
+         new_width, new_height = input_width, int(input_width / aspect_ratio_target)
+         x_start, y_start = 0, (input_height - new_height) // 2
+
+     image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+     image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
+
+     frame_tensor = TVF.to_tensor(image)
+     frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=(3, 3))
+
+     frame_tensor_hwc = frame_tensor.permute(1, 2, 0)
+     frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
+     frame_tensor = frame_tensor_hwc.permute(2, 0, 1)
+
+     frame_tensor = (frame_tensor * 2.0) - 1.0
+     return frame_tensor.unsqueeze(0).unsqueeze(2)
+
+
+ # ==============================================================================
+ # --- WORK FUNCTIONS (jobs executed on the pool) ---
+ # ==============================================================================
+
+ def _job_encode_media(vae: CausalVideoAutoencoder, pixel_tensor: torch.Tensor) -> torch.Tensor:
+     """Generic work function that encodes a pixel tensor."""
+     device = vae.device
+     dtype = vae.dtype
+     pixel_tensor_gpu = pixel_tensor.to(device, dtype=dtype)
+     latents = vae_encode(pixel_tensor_gpu, vae, vae_per_channel_normalize=True)
+     return latents.cpu()
+
+ def _job_decode_latent_to_pixels(vae: CausalVideoAutoencoder, latent_tensor: torch.Tensor) -> torch.Tensor:
+     """Work function that decodes a latent tensor."""
+     device = vae.device
+     dtype = vae.dtype
+     latent_tensor_gpu = latent_tensor.to(device, dtype=dtype)
+     pixels = vae_decode(latent_tensor_gpu, vae, is_video=True, vae_per_channel_normalize=True)
+     return pixels.cpu()
+
+ # ==============================================================================
+ # --- THE CLIENT CLASS (public interface) ---
+ # ==============================================================================
+
+ class VaeAducPipeline:
+     """High-level client that orchestrates all VAE tasks."""
      def __init__(self):
+         logging.info("✅ VAE ADUC Pipeline (Client) initialized and ready to submit jobs.")

+     def __call__(
          self,
+         media: Union[torch.Tensor, List[Union[Image.Image, torch.Tensor]]],
+         task: Literal['encode', 'decode', 'create_conditioning_items'],
+         target_resolution: Optional[Tuple[int, int]] = (512, 512),
+         conditioning_params: Optional[List[Tuple[int, float]]] = None
+     ) -> Union[List[torch.Tensor], torch.Tensor, List[LatentConditioningItem]]:
          """
+         Main entry point for executing VAE tasks.
+
+         Args:
+             media: The input data.
+             task: The task to execute ('encode', 'decode', 'create_conditioning_items').
+             target_resolution: The (height, width) resolution for preprocessing.
+             conditioning_params: For 'create_conditioning_items', a list of
+                 (frame_number, strength) tuples, one per media item.
+
+         Returns:
+             The result of the task, always on the CPU.
          """
          t0 = time.time()
+         logging.info(f"VAE Client received a '{task}' job.")
+
+         if task == 'encode':
+             if not isinstance(media, list): media = [media]
+             pixel_tensors = [load_image_to_tensor_with_resize_and_crop(m, target_resolution[0], target_resolution[1]) for m in media]
+             results = []
+             for pt in pixel_tensors:
+                 latent = ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_encode_media, pixel_tensor=pt)
+                 results.append(latent)
+             return results
+
+         elif task == 'decode':
+             if not isinstance(media, torch.Tensor):
+                 raise TypeError("For 'decode', 'media' must be a single latent tensor.")
+             return ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_decode_latent_to_pixels, latent_tensor=media)
+
+         elif task == 'create_conditioning_items':
+             if not isinstance(media, list) or not isinstance(conditioning_params, list) or len(media) != len(conditioning_params):
+                 raise ValueError("For 'create_conditioning_items', 'media' and 'conditioning_params' must be lists of the same length.")
+
+             pixel_tensors = [load_image_to_tensor_with_resize_and_crop(m, target_resolution[0], target_resolution[1]) for m in media]
+             conditioning_items = []
+             for i, pt in enumerate(pixel_tensors):
+                 latent_tensor = ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_encode_media, pixel_tensor=pt)
+                 frame_number, strength = conditioning_params[i]
+                 conditioning_items.append(LatentConditioningItem(
+                     latent_tensor=latent_tensor,
+                     media_frame_number=frame_number,
+                     conditioning_strength=strength
+                 ))
              return conditioning_items
+
+         else:
+             raise ValueError(f"Unknown task: '{task}'. Options: 'encode', 'decode', 'create_conditioning_items'.")
+
+ # --- CLIENT SINGLETON INSTANCE ---
  try:
+     vae_aduc_pipeline = VaeAducPipeline()
  except Exception as e:
+     logging.critical("CRITICAL: Failed to initialize the VaeAducPipeline client.", exc_info=True)
+     vae_aduc_pipeline = None
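
For reference, a minimal usage sketch of the new client API introduced by this commit. This is hypothetical, not part of the commit: it assumes the LTXAducManager worker pool is already initialized, and the file names, frame numbers, and strengths are illustrative only.

from PIL import Image
from api.ltx.vae_aduc_pipeline import vae_aduc_pipeline

# Encode two reference frames into CPU latents via the 'vae' job pool.
# "frame_a.png" / "frame_b.png" are placeholder paths.
frames = [Image.open("frame_a.png"), Image.open("frame_b.png")]
latents = vae_aduc_pipeline(frames, task='encode', target_resolution=(512, 512))

# Build conditioning items: pin the first frame at frame 0 with full strength,
# the second at frame 48 with half strength (illustrative values).
items = vae_aduc_pipeline(
    frames,
    task='create_conditioning_items',
    target_resolution=(512, 512),
    conditioning_params=[(0, 1.0), (48, 0.5)],
)

# Decode a latent back into a pixel tensor, returned on the CPU.
pixels = vae_aduc_pipeline(latents[0], task='decode')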