eeuuia commited on
Commit
da355b7
·
verified ·
1 Parent(s): 8bf282e

Create ltx_pool_manager

Browse files
Files changed (1) hide show
  1. api/ltx_pool_manager +208 -0
api/ltx_pool_manager ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FILE: api/ltx_pool_manager.py
2
+ # DESCRIPTION: A singleton pool manager for the LTX-Video pipeline.
3
+ # This module is the "secret weapon": it handles loading, device placement,
4
+ # and applies a runtime monkey patch to the LTX pipeline for full control
5
+ # and compatibility with the ADUC-SDR architecture, especially for latent conditioning.
6
+
7
+ import logging
8
+ import time
9
+ import os
10
+ import yaml
11
+ import json
12
+ from pathlib import Path
13
+ from typing import List, Optional, Tuple, Union
14
+ from dataclasses import dataclass
15
+
16
+ import torch
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from huggingface_hub import hf_hub_download
19
+
20
+ # --- Importações da nossa arquitetura ---
21
+ from api.gpu_manager import gpu_manager
22
+
23
+ # --- Importações da biblioteca LTX-Video e Utilitários ---
24
+ from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu
25
+ from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
26
+ from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
27
+
28
+ # ==============================================================================
29
+ # --- DEFINIÇÃO DOS NOSSOS DATACLASSES DE CONDICIONAMENTO ---
30
+ # ==============================================================================
31
+
32
+ @dataclass
33
+ class ConditioningItem:
34
+ """Nosso Data Class para condicionamento com TENSORES DE PIXEL (de imagens)."""
35
+ pixel_tensor: torch.Tensor
36
+ media_frame_number: int
37
+ conditioning_strength: float
38
+
39
+ @dataclass
40
+ class LatentConditioningItem:
41
+ """Nossa "arma secreta": um Data Class para condicionamento com TENSORES LATENTES (de overlap)."""
42
+ latent_tensor: torch.Tensor
43
+ media_frame_number: int
44
+ conditioning_strength: float
45
+
46
+ # ==============================================================================
47
+ # --- O MONKEY PATCH ---
48
+ # Nossa versão customizada de `prepare_conditioning` que entende ambos os Data Classes.
49
+ # ==============================================================================
50
+
51
+ def _aduc_prepare_conditioning_patch(
52
+ self: "LTXVideoPipeline",
53
+ conditioning_items: Optional[List[Union[ConditioningItem, LatentConditioningItem]]],
54
+ init_latents: torch.Tensor,
55
+ num_frames: int, height: int, width: int, # Assinatura mantida para compatibilidade
56
+ vae_per_channel_normalize: bool = False,
57
+ generator=None,
58
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
59
+
60
+ # Se não houver itens, apenas "patchify" os latentes iniciais e retorna.
61
+ if not conditioning_items:
62
+ latents, latent_coords = self.patchifier.patchify(latents=init_latents)
63
+ pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
64
+ return latents, pixel_coords, None, 0
65
+
66
+ # Prepara máscaras e listas para acumular os tensores de condição.
67
+ init_conditioning_mask = torch.zeros_like(init_latents[:, 0, ...], dtype=torch.float32, device=init_latents.device)
68
+ extra_conditioning_latents, extra_conditioning_pixel_coords, extra_conditioning_mask = [], [], []
69
+ extra_conditioning_num_latents = 0
70
+
71
+ for item in conditioning_items:
72
+ strength = item.conditioning_strength
73
+ media_frame_number = item.media_frame_number
74
+
75
+ # --- LÓGICA PRINCIPAL DO PATCH ---
76
+ if isinstance(item, ConditioningItem):
77
+ # Item é um tensor de PIXEL (ex: imagem inicial).
78
+ logging.debug("Patch ADUC: Processando ConditioningItem (pixels).")
79
+ # Encodifica o tensor de pixel para o espaço latente usando o VAE.
80
+ # Garante que a operação ocorra no dispositivo do VAE para evitar erros.
81
+ pixel_tensor_on_vae_device = item.pixel_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
82
+ media_item_latents = vae_encode(pixel_tensor_on_vae_device, self.vae, vae_per_channel_normalize=vae_per_channel_normalize)
83
+ # Traz o resultado de volta para o dispositivo principal (do Transformer).
84
+ media_item_latents = media_item_latents.to(device=init_latents.device, dtype=init_latents.dtype)
85
+
86
+ elif isinstance(item, LatentConditioningItem):
87
+ # Item já é um tensor LATENTE (ex: overlap de chunks).
88
+ logging.debug("Patch ADUC: Processando LatentConditioningItem (latentes).")
89
+ # Apenas garante que o tensor está no dispositivo e tipo corretos.
90
+ media_item_latents = item.latent_tensor.to(device=init_latents.device, dtype=init_latents.dtype)
91
+ else:
92
+ logging.warning(f"Patch ADUC: Item de condicionamento de tipo desconhecido '{type(item)}' será ignorado.")
93
+ continue
94
+
95
+ # Lógica original da pipeline, agora operando sobre `media_item_latents` garantido.
96
+ if media_frame_number == 0:
97
+ f_l, h_l, w_l = media_item_latents.shape[-3:]
98
+ init_latents[..., :f_l, :h_l, :w_l] = torch.lerp(init_latents[..., :f_l, :h_l, :w_l], media_item_latents, strength)
99
+ init_conditioning_mask[..., :f_l, :h_l, :w_l] = strength
100
+ else: # Condicionamento em frames intermediários
101
+ noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
102
+ media_item_latents = torch.lerp(noise, media_item_latents, strength)
103
+ patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
104
+ pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
105
+ pixel_coords[:, 0] += media_frame_number
106
+ extra_conditioning_num_latents += patched_latents.shape[1]
107
+ new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
108
+ extra_conditioning_latents.append(patched_latents)
109
+ extra_conditioning_pixel_coords.append(pixel_coords)
110
+ extra_conditioning_mask.append(new_mask)
111
+
112
+ # Finaliza o processo de patchifying e concatenação dos tensores.
113
+ init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
114
+ init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
115
+ init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
116
+ init_conditioning_mask = init_conditioning_mask.squeeze(-1)
117
+
118
+ if extra_conditioning_latents:
119
+ init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
120
+ init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
121
+ init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
122
+
123
+ return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
124
+
125
+ # ==============================================================================
126
+ # --- LTX WORKER E POOL MANAGER ---
127
+ # ==============================================================================
128
+
129
+ class LTXWorker:
130
+ """Gerencia uma instância do LTX Pipeline em um par de GPUs (main + vae)."""
131
+ def __init__(self, main_device_str: str, vae_device_str: str, config: dict):
132
+ self.main_device = torch.device(main_device_str)
133
+ self.vae_device = torch.device(vae_device_str)
134
+ self.config = config
135
+ self.pipeline: LTXVideoPipeline = None
136
+
137
+ self._load_and_patch_pipeline()
138
+
139
+ def _load_and_patch_pipeline(self):
140
+ logging.info(f"[LTXWorker-{self.main_device}] Carregando pipeline LTX para a CPU...")
141
+ self.pipeline, _ = build_ltx_pipeline_on_cpu(self.config)
142
+
143
+ logging.info(f"[LTXWorker-{self.main_device}] Movendo pipeline para GPUs (Main: {self.main_device}, VAE: {self.vae_device})...")
144
+ self.pipeline.to(self.main_device)
145
+ self.pipeline.vae.to(self.vae_device)
146
+
147
+ logging.info(f"[LTXWorker-{self.main_device}] Aplicando patch ADUC-SDR na função 'prepare_conditioning'...")
148
+ # Substitui o método da instância pelo nosso patch
149
+ self.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(self.pipeline, LTXVideoPipeline)
150
+ logging.info(f"[LTXWorker-{self.main_device}] ✅ Pipeline 'quente', corrigido e pronto para uso.")
151
+
152
+ class LTXPoolManager:
153
+ _instance = None
154
+ _lock = threading.Lock()
155
+
156
+ def __new__(cls, *args, **kwargs):
157
+ with cls._lock:
158
+ if cls._instance is None:
159
+ cls._instance = super().__new__(cls)
160
+ cls._instance._initialized = False
161
+ return cls._instance
162
+
163
+ def __init__(self):
164
+ if self._initialized: return
165
+ with self._lock:
166
+ if self._initialized: return
167
+
168
+ logging.info("⚙️ Inicializando LTXPoolManager Singleton...")
169
+ self.config = self._load_config()
170
+ self._resolve_model_paths_from_cache()
171
+
172
+ main_device_str = str(gpu_manager.get_ltx_device())
173
+ vae_device_str = str(gpu_manager.get_ltx_vae_device())
174
+
175
+ self.worker = LTXWorker(main_device_str, vae_device_str, self.config)
176
+
177
+ self._initialized = True
178
+ logging.info("✅ LTXPoolManager pronto.")
179
+
180
+ def _load_config(self) -> Dict:
181
+ """Carrega a configuração YAML principal do LTX."""
182
+ config_path = Path("/data/LTX-Video/configs/ltxv-13b-0.9.8-distilled-fp8.yaml")
183
+ with open(config_path, "r") as file:
184
+ return yaml.safe_load(file)
185
+
186
+ def _resolve_model_paths_from_cache(self):
187
+ """Garante que a configuração em memória tenha os caminhos absolutos para os modelos no cache."""
188
+ try:
189
+ main_ckpt_path = hf_hub_download(repo_id="Lightricks/LTX-Video", filename=self.config["checkpoint_path"])
190
+ self.config["checkpoint_path"] = main_ckpt_path
191
+ if self.config.get("spatial_upscaler_model_path"):
192
+ upscaler_path = hf_hub_download(repo_id="Lightricks/LTX-Video", filename=self.config["spatial_upscaler_model_path"])
193
+ self.config["spatial_upscaler_model_path"] = upscaler_path
194
+ except Exception as e:
195
+ logging.critical(f"Falha ao resolver caminhos de modelo LTX. O setup.py foi executado? Erro: {e}", exc_info=True)
196
+ raise
197
+
198
+ def get_pipeline(self) -> LTXVideoPipeline:
199
+ """Retorna a instância do pipeline, já carregada e corrigida."""
200
+ return self.worker.pipeline
201
+
202
+ # --- Instância Singleton Global ---
203
+ # A aplicação importará esta instância para interagir com o LTX.
204
+ try:
205
+ ltx_pool_manager = LTXPoolManager()
206
+ except Exception as e:
207
+ logging.critical("Falha crítica ao inicializar o LTXPoolManager.", exc_info=True)
208
+ ltx_pool_manager = None