eeuuia committed on
Commit
655068e
·
verified ·
1 Parent(s): b521010

Update api/ltx/ltx_aduc_pipeline.py

Browse files
Files changed (1) hide show
  1. api/ltx/ltx_aduc_pipeline.py +232 -140
api/ltx/ltx_aduc_pipeline.py CHANGED
@@ -1,175 +1,267 @@
1
  # FILE: api/ltx/ltx_aduc_pipeline.py
2
- # DESCRIPTION: A unified high-level client for submitting ALL LTX-related jobs (generation and VAE)
3
- # to the LTXAducManager pool.
 
4
 
 
 
5
  import logging
 
 
 
 
6
  import time
7
- import torch
 
8
  import random
9
- from typing import List, Optional, Tuple, Dict
 
 
10
  from PIL import Image
11
- from dataclasses import dataclass
12
- from pathlib import Path
13
- import sys
14
 
15
- from api.ltx.ltx_utils import load_image_to_tensor_with_resize_and_crop # Importa o helper de ltx_utils
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # O cliente importa o MANAGER para submeter todos os trabalhos.
18
- from api.ltx.ltx_aduc_manager import ltx_aduc_manager
19
 
20
- # Adiciona o path do LTX-Video para importações de baixo nível e tipos.
21
- LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
22
  def add_deps_to_path():
23
  repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
24
  if repo_path not in sys.path:
25
  sys.path.insert(0, repo_path)
26
  add_deps_to_path()
27
 
28
- from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
29
- from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
30
-
 
 
 
 
 
 
31
 
32
  # ==============================================================================
33
- # --- DEFINIÇÕES DE ESTRUTURA ---
34
  # ==============================================================================
35
 
36
- @dataclass
37
- class LatentConditioningItem:
38
- """Estrutura de dados para passar latentes condicionados ao job de geração."""
39
- latent_tensor: torch.Tensor
40
- media_frame_number: int
41
- conditioning_strength: float
42
-
43
- # ==============================================================================
44
- # --- FUNÇÕES DE TRABALHO (Jobs a serem executados no Pool LTX) ---
45
- # ==============================================================================
46
 
47
- def _job_encode_media(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, pixel_tensor: torch.Tensor) -> torch.Tensor:
48
- """Job que usa o VAE do pipeline para codificar um tensor de pixel."""
49
- vae = pipeline.vae
50
- pixel_tensor_gpu = pixel_tensor.to(vae.device, dtype=vae.dtype)
51
- latents = vae_encode(pixel_tensor_gpu, vae, vae_per_channel_normalize=True)
52
- return latents.cpu()
53
-
54
- def _job_decode_latent(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, latent_tensor: torch.Tensor) -> torch.Tensor:
55
- """Job que usa o VAE do pipeline para decodificar um tensor latente."""
56
- vae = pipeline.vae
57
- latent_tensor_gpu = latent_tensor.to(vae.device, dtype=vae.dtype)
58
- pixels = vae_decode(latent_tensor_gpu, vae, is_video=True, vae_per_channel_normalize=True)
59
- return pixels.cpu()
60
-
61
- def _job_generate_latent_chunk(pipeline: LTXVideoPipeline, autocast_dtype: torch.dtype, **kwargs) -> torch.Tensor:
62
- """Job que usa o pipeline principal para gerar um chunk de vídeo latente."""
63
- generator = torch.Generator(device=pipeline.device).manual_seed(kwargs['seed'])
64
- pipeline_kwargs = {"generator": generator, "output_type": "latent", **kwargs}
65
-
66
- with torch.autocast(device_type=pipeline.device.type, dtype=autocast_dtype):
67
- latents_raw = pipeline(**pipeline_kwargs).images
68
 
69
- return latents_raw.cpu()
 
70
 
71
- # ==============================================================================
72
- # --- A CLASSE CLIENTE UNIFICADA ---
73
- # ==============================================================================
 
 
 
 
74
 
75
- class LtxAducPipeline:
76
- """
77
- Cliente unificado para orquestrar todas as tarefas LTX, incluindo geração e VAE.
78
- """
79
- def __init__(self):
80
- logging.info("✅ Unified LTX/VAE ADUC Pipeline (Client) initialized.")
81
- self.FRAMES_ALIGNMENT = 8
 
 
 
82
 
83
- def _get_random_seed(self) -> int:
84
- return random.randint(0, 2**32 - 1)
 
85
 
86
- def _align(self, dim: int, alignment: int = 8) -> int:
87
- return ((dim + alignment - 1) // alignment) * alignment
88
-
89
- # --- Métodos de API para o Orquestrador ---
90
-
91
- def encode_to_conditioning_items(self, media_list: List, params: List, resolution: Tuple[int, int]) -> List[LatentConditioningItem]:
92
- """Converte uma lista de imagens em uma lista de LatentConditioningItem."""
93
- pixel_tensors = [load_image_to_tensor_with_resize_and_crop(m, resolution[0], resolution[1]) for m in media_list]
94
- items = []
95
- for i, pt in enumerate(pixel_tensors):
96
- latent_tensor = ltx_aduc_manager.submit_job(_job_encode_media, pixel_tensor=pt)
97
- frame_number, strength = params[i]
98
- items.append(LatentConditioningItem(
99
- latent_tensor=latent_tensor,
100
- media_frame_number=frame_number,
101
- conditioning_strength=strength
102
- ))
103
- return items
104
-
105
- def decode_to_pixels(self, latent_tensor: torch.Tensor) -> torch.Tensor:
106
- """Decodifica um tensor latente em um tensor de pixels."""
107
- return ltx_aduc_manager.submit_job(_job_decode_latent, latent_tensor=latent_tensor)
108
-
109
- def generate_latents(
110
  self,
111
  prompt_list: List[str],
112
- duration_in_seconds: float,
113
- common_ltx_args: Dict,
114
- initial_conditioning_items: Optional[List[LatentConditioningItem]] = None
115
- ) -> Tuple[Optional[torch.Tensor], Optional[int]]:
116
- """Gera um vídeo latente completo a partir de uma lista de prompts."""
117
- t0 = time.time()
118
- logging.info(f"LTX Client received a generation job for {len(prompt_list)} scenes.")
119
  used_seed = self._get_random_seed()
 
 
120
 
 
 
121
  num_chunks = len(prompt_list)
122
- total_frames = self._align(int(duration_in_seconds * 24))
123
- frames_per_chunk_base = total_frames // num_chunks if num_chunks > 0 else total_frames
124
- overlap_frames = self._align(9) if num_chunks > 1 else 0
125
-
126
- final_latents_list = []
127
- overlap_condition_item = None
128
-
129
- for i, chunk_prompt in enumerate(prompt_list):
130
- current_conditions = []
131
- if i == 0 and initial_conditioning_items:
132
- current_conditions.extend(initial_conditioning_items)
133
- if overlap_condition_item:
134
- current_conditions.append(overlap_condition_item)
135
-
136
- num_frames_for_chunk = frames_per_chunk_base
137
- if i == num_chunks - 1:
138
- processed_frames = sum(f.shape[2] for f in final_latents_list)
139
- num_frames_for_chunk = total_frames - processed_frames
140
- num_frames_for_chunk = self._align(num_frames_for_chunk)
141
- if num_frames_for_chunk <= 0: continue
142
-
143
- job_specific_args = {
144
- "prompt": chunk_prompt,
145
- "num_frames": num_frames_for_chunk,
146
- "seed": used_seed + i,
147
- "conditioning_items": current_conditions
148
- }
149
- final_job_args = {**common_ltx_args, **job_specific_args}
150
-
151
- chunk_latents = ltx_aduc_manager.submit_job(_job_generate_latent_chunk, **final_job_args)
152
 
153
- if chunk_latents is None:
154
- logging.error(f"Failed to generate latents for scene {i+1}. Aborting.")
155
- return None, used_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
- if i < num_chunks - 1:
158
- overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
159
- overlap_condition_item = LatentConditioningItem(
160
- latent_tensor=overlap_latents, media_frame_number=0, conditioning_strength=1.0)
161
- final_latents_list.append(chunk_latents[:, :, :-overlap_frames, :, :])
162
- else:
163
- final_latents_list.append(chunk_latents)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
- if not final_latents_list:
166
- logging.warning("No latent chunks were generated.")
167
- return None, used_seed
168
-
169
- final_latents = torch.cat(final_latents_list, dim=2)
170
- logging.info(f"LTX Client job finished in {time.time() - t0:.2f}s. Final latent shape: {final_latents.shape}")
171
 
172
- return final_latents, used_seed
173
 
174
- # --- INSTÂNCIA SINGLETON DO CLIENTE ---
175
- ltx_aduc_pipeline = LtxAducPipeline()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # FILE: api/ltx/ltx_aduc_pipeline.py
2
+ # DESCRIPTION: Final high-level orchestrator for LTX-Video generation.
3
+ # This version acts as a client to the specialized managers (LTX, VAE),
4
+ # focusing solely on the business logic of video generation workflows.
5
 
6
+ import gc
7
+ import json
8
  import logging
9
+ import os
10
+ import shutil
11
+ import sys
12
+ import tempfile
13
  import time
14
+ from pathlib import Path
15
+ from typing import Dict, List, Optional, Tuple, Union
16
  import random
17
+ import torch
18
+ import yaml
19
+ import numpy as np
20
  from PIL import Image
 
 
 
21
 
22
# ==============================================================================
# --- PROJECT SETUP AND IMPORTS ---
# ==============================================================================

# Logging configuration and warning suppression.
import warnings
warnings.filterwarnings("ignore")
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
# Log level is overridable through the ADUC_LOG_LEVEL environment variable.
log_level = os.environ.get("ADUC_LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=log_level, format='[%(levelname)s] [%(name)s] %(message)s')

# --- Configuration constants ---
DEPS_DIR = Path("/data")                     # root directory for external dependencies
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"  # local clone of the LTX-Video repository
RESULTS_DIR = Path("/app/output")            # destination for final videos and latents
DEFAULT_FPS = 24.0                           # fps used for duration->frames math and encoding
FRAMES_ALIGNMENT = 8                         # frame counts are aligned to multiples of 8

from api.ltx.ltx_utils import seed_everything
from utils.debug_utils import log_function_io
42
 
 
 
43
 
44
# Make the vendored LTX-Video library importable.
def add_deps_to_path():
    """Prepend the resolved LTX-Video repo path to sys.path (idempotent)."""
    resolved = str(LTX_VIDEO_REPO_DIR.resolve())
    if resolved in sys.path:
        return
    sys.path.insert(0, resolved)
add_deps_to_path()
50
 
51
# --- Modules from our architecture ---
# These singletons are created at import time by their own modules; a failure
# here means the service cannot run, so the process exits immediately.
try:
    from api.managers.gpu_manager import gpu_manager
    from api.ltx.ltx_aduc_manager import ltx_pool_manager, LatentConditioningItem
    from api.ltx.vae_aduc_pipeline import vae_server_singleton
    from tools.video_encode_tool import video_encode_tool_singleton
except ImportError as e:
    logging.critical(f"A crucial import from the local API/architecture failed. Error: {e}", exc_info=True)
    sys.exit(1)
60
 
61
  # ==============================================================================
62
+ # --- CLASSE DE SERVIÇO (O ORQUESTRADOR) ---
63
  # ==============================================================================
64
 
65
class VideoService:
    """
    High-level orchestrator for video generation. All low-level work is
    delegated to the specialized managers and utility modules.
    """

    @log_function_io
    def __init__(self):
        t0 = time.time()
        logging.info("Initializing VideoService Orchestrator...")

        # Both singletons must have come up successfully at import time.
        managers_ok = ltx_pool_manager is not None and vae_server_singleton is not None
        if not managers_ok:
            raise RuntimeError("A required manager (LTX or VAE) failed to initialize. Aborting.")

        self.pipeline = ltx_pool_manager.get_pipeline()
        self.config = ltx_pool_manager.config
        # Generation and VAE stages may live on different devices.
        self.main_device = self.pipeline.device
        self.vae_device = self.pipeline.vae.device

        self._apply_precision_policy()
        logging.info(f"VideoService ready. Using Main: {self.main_device}, VAE: {self.vae_device}. Startup time: {time.time() - t0:.2f}s")
86
 
87
+ def finalize(self):
88
+ """Cleans up GPU memory after a generation task."""
89
+ gc.collect()
90
+ if torch.cuda.is_available():
91
+ with torch.cuda.device(self.main_device):
92
+ torch.cuda.empty_cache()
93
+ with torch.cuda.device(self.vae_device):
94
+ torch.cuda.empty_cache()
95
+ try: torch.cuda.ipc_collect()
96
+ except Exception: pass
97
 
98
    # ==========================================================================
    # --- BUSINESS LOGIC: UNIFIED PUBLIC ORCHESTRATOR ---
    # ==========================================================================

    @log_function_io
    def generate_low_resolution(
        self,
        prompt_list: List[str],
        initial_media_items: Optional[List[Tuple[Union[str, Image.Image, torch.Tensor], int, float]]] = None,
        **kwargs
    ) -> Tuple[Optional[str], Optional[str], Optional[int]]:
        """
        [UNIFIED ORCHESTRATOR] Generates a video from a list of prompts and raw media items.

        One latent chunk is generated per prompt; consecutive chunks are
        stitched by conditioning each chunk on the last `overlap_frames`
        latent frames of the previous one, then trimming that overlap before
        concatenation. Chunks are spilled to disk (.pt files) between steps
        and reloaded only for the final concatenation.

        Args:
            prompt_list: One prompt per scene; must be non-empty.
            initial_media_items: Optional (media, target_frame, strength)
                tuples, turned into conditioning items for the FIRST scene only.
            **kwargs: 'duration' (seconds, default 4.0) is read here; 'height'
                and 'width' are required at least when media items are given
                (and are also read by _generate_single_chunk_low, to which all
                remaining kwargs are forwarded).

        Returns:
            (video_path, latents_path, used_seed) on success;
            (None, None, None) if any step fails (the error is logged).
        """
        logging.info("Starting unified low-resolution generation...")
        used_seed = self._get_random_seed()
        seed_everything(used_seed)
        logging.info(f"Using randomly generated seed: {used_seed}")

        if not prompt_list: raise ValueError("Prompt list cannot be empty.")
        is_narrative = len(prompt_list) > 1
        num_chunks = len(prompt_list)
        total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
        # Per-chunk budget, floored to a multiple of FRAMES_ALIGNMENT but never below it.
        frames_per_chunk = max(FRAMES_ALIGNMENT, (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT)
        # NOTE(review): 9 presumably matches the 8n+1 latent-frame rule used by
        # _align(..., 'n*8+1') — confirm against the LTX pipeline's constraints.
        overlap_frames = 9 if is_narrative else 0

        initial_conditions = []
        if initial_media_items:
            logging.info("Delegating to VaeServer to prepare initial conditioning items...")
            initial_conditions = vae_server_singleton.generate_conditioning_items(
                media_items=[item[0] for item in initial_media_items],
                target_frames=[item[1] for item in initial_media_items],
                strengths=[item[2] for item in initial_media_items],
                target_resolution=(kwargs['height'], kwargs['width'])
            )

        temp_latent_paths = []
        overlap_condition_item: Optional[LatentConditioningItem] = None

        try:
            for i, chunk_prompt in enumerate(prompt_list):
                logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")

                # The last chunk absorbs the rounding remainder; chunks after the
                # first get extra frames to cover the overlap that is trimmed below.
                current_frames_base = frames_per_chunk if i < num_chunks - 1 else total_frames - ((num_chunks - 1) * frames_per_chunk)
                current_frames = current_frames_base + (overlap_frames if i > 0 else 0)
                current_frames = self._align(current_frames, alignment_rule='n*8+1')

                # NOTE(review): on i == 0 this ALIASES initial_conditions rather
                # than copying it. Harmless today only because
                # overlap_condition_item is still None on the first iteration,
                # so nothing is appended to the shared list — confirm if edited.
                current_conditions = initial_conditions if i == 0 else []
                if overlap_condition_item: current_conditions.append(overlap_condition_item)

                chunk_latents = self._generate_single_chunk_low(
                    prompt=chunk_prompt, num_frames=current_frames, seed=used_seed + i,
                    conditioning_items=current_conditions, **kwargs
                )
                if chunk_latents is None: raise RuntimeError(f"Failed to generate latents for scene {i+1}.")

                # Capture the tail of this chunk (dim 2 = latent frames) as the
                # conditioning anchor for the next scene.
                if is_narrative and i < num_chunks - 1:
                    overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
                    overlap_condition_item = LatentConditioningItem(
                        latent_tensor=overlap_latents.cpu(),
                        media_frame_number=0,
                        conditioning_strength=1.0
                    )

                # Drop the leading overlap that duplicates the previous chunk's tail.
                if i > 0: chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]

                # Spill to disk so only one chunk's latents stay in memory at a time.
                chunk_path = RESULTS_DIR / f"temp_chunk_{i}_{used_seed}.pt"
                torch.save(chunk_latents.cpu(), chunk_path)
                temp_latent_paths.append(chunk_path)

            base_filename = "narrative_video" if is_narrative else "single_video"
            all_tensors_cpu = [torch.load(p) for p in temp_latent_paths]
            final_latents = torch.cat(all_tensors_cpu, dim=2)

            video_path, latents_path = self._finalize_generation(final_latents, base_filename, used_seed)
            return video_path, latents_path, used_seed
        except Exception as e:
            logging.error(f"Error during unified generation: {e}", exc_info=True)
            return None, None, None
        finally:
            # Always remove temp chunk files and release GPU memory, even on failure.
            for path in temp_latent_paths:
                if path.exists(): path.unlink()
            self.finalize()
181
+
182
    # ==========================================================================
    # --- INTERNAL WORK UNITS AND HELPERS ---
    # ==========================================================================

    @log_function_io
    def _generate_single_chunk_low(self, **kwargs) -> Optional[torch.Tensor]:
        """[WORKER] Calls the patched LTX pipeline to generate a single chunk of latents.

        Expects in kwargs: 'prompt', 'negative_prompt', 'height', 'width',
        'num_frames', 'seed', 'conditioning_items'; optional
        'ltx_configs_override' to patch the first-pass config.
        Returns the raw latent tensor on self.main_device.
        """
        # Pad requested resolution to FRAMES_ALIGNMENT, then downscale for the
        # low-res first pass, re-aligned to the VAE's scale factor.
        height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
        downscale_factor = self.config.get("downscale_factor", 0.6666666)
        vae_scale_factor = self.pipeline.vae_scale_factor
        downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
        downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)

        # Shallow-copy the first-pass config so UI overrides never mutate self.config.
        first_pass_config = self.config.get("first_pass", {}).copy()
        if kwargs.get("ltx_configs_override"):
            self._apply_ui_overrides(first_pass_config, kwargs["ltx_configs_override"])

        # NOTE(review): **first_pass_config is expanded last, so config keys
        # silently win over the explicit keys above if they collide — confirm
        # that is intended.
        pipeline_kwargs = {
            "prompt": kwargs['prompt'], "negative_prompt": kwargs['negative_prompt'],
            "height": downscaled_height, "width": downscaled_width, "num_frames": kwargs['num_frames'],
            "frame_rate": int(DEFAULT_FPS), "generator": torch.Generator(device=self.main_device).manual_seed(kwargs['seed']),
            "output_type": "latent", "conditioning_items": kwargs['conditioning_items'], **first_pass_config
        }

        # Autocast only on CUDA devices, using the policy chosen at startup.
        with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
            latents_raw = self.pipeline(**pipeline_kwargs).images

        return latents_raw.to(self.main_device)
210
 
211
+ @log_function_io
212
+ def _finalize_generation(self, final_latents: torch.Tensor, base_filename: str, seed: int) -> Tuple[str, str]:
213
+ """Delegates final decoding and encoding to specialist services."""
214
+ logging.info("Finalizing generation: decoding latents and encoding video.")
215
+
216
+ final_latents_path = RESULTS_DIR / f"latents_{base_filename}_{seed}.pt"
217
+ torch.save(final_latents, final_latents_path)
218
+ logging.info(f"Final latents saved to: {final_latents_path}")
219
+
220
+ pixel_tensor = vae_server_singleton.decode_to_pixels(
221
+ final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05))
222
+ )
223
+ video_path = self._save_and_log_video(pixel_tensor, f"{base_filename}_{seed}")
224
+ return str(video_path), str(final_latents_path)
225
+
226
    def _apply_ui_overrides(self, config_dict: Dict, overrides: Dict):
        # NOTE(review): stub — the UI override logic is not implemented, so any
        # 'ltx_configs_override' passed by callers is currently ignored.
        # When implemented it is expected to mutate config_dict in place
        # (callers pass a copy and do not use a return value).
        pass
229
+
230
+ def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
231
+ with tempfile.TemporaryDirectory() as temp_dir:
232
+ temp_path = os.path.join(temp_dir, f"{base_filename}.mp4")
233
+ video_encode_tool_singleton.save_video_from_tensor(pixel_tensor, temp_path, fps=DEFAULT_FPS)
234
+ final_path = RESULTS_DIR / f"{base_filename}.mp4"
235
+ shutil.move(temp_path, final_path)
236
+ logging.info(f"Video saved successfully to: {final_path}")
237
+ return final_path
238
+
239
+ def _apply_precision_policy(self):
240
+ precision = str(self.config.get("precision", "bfloat16")).lower()
241
+ if precision in ["float8_e4m3fn", "bfloat16"]: self.runtime_autocast_dtype = torch.bfloat16
242
+ elif precision == "mixed_precision": self.runtime_autocast_dtype = torch.float16
243
+ else: self.runtime_autocast_dtype = torch.float32
244
+ logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")
245
+
246
+ def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT, alignment_rule: str = 'default') -> int:
247
+ if alignment_rule == 'n*8+1':
248
+ return ((dim - 1) // alignment) * alignment + 1
249
+ return ((dim - 1) // alignment + 1) * alignment
250
+
251
+ def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
252
+ num_frames = int(round(duration_s * DEFAULT_FPS))
253
+ aligned_frames = self._align(num_frames, alignment=FRAMES_ALIGNMENT)
254
+ return max(aligned_frames, min_frames)
255
+
256
+ def _get_random_seed(self) -> int:
257
+ return random.randint(0, 2**32 - 1)
258
+
259
# ==============================================================================
# --- SINGLETON INSTANTIATION ---
# ==============================================================================
# NOTE(review): instantiated at import time, so merely importing this module
# spins up the whole service and kills the process on failure — confirm every
# importer expects that.
try:
    video_generation_service = VideoService()
    logging.info("Global VideoService orchestrator instance created successfully.")
except Exception as e:
    logging.critical(f"Failed to initialize VideoService: {e}", exc_info=True)
    sys.exit(1)