caarleexx committed
Commit f433ba6 · verified
1 Parent(s): 9094e95

Upload ltx-video-complete.py

Files changed (1)
  1. api/ltx-video-complete.py +1215 -0
api/ltx-video-complete.py ADDED
# ==============================================================================
# ltx_video_service_with_gpu_pools.py
# VideoService with an integrated Multi-GPU Pool Manager
# ==============================================================================
# Architecture:
# - GPUs 0 and 1: Pipeline + Upscaler (latent generation/refinement)
# - GPUs 2 and 3: VAE Decode (decoding latents to pixels)
# ==============================================================================

import os
import sys
import gc
import copy
import yaml
import time
import json
import random
import shutil
import warnings
import tempfile
import traceback
import subprocess
import threading
import queue
from pathlib import Path
from typing import Callable, List, Dict, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum
import cv2
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from einops import rearrange
from huggingface_hub import hf_hub_download
from safetensors import safe_open

# --- Settings ---
ENABLE_MEMORY_OPTIMIZATION = os.getenv("ADUC_MEMORY_OPTIMIZATION", "1").lower() in ["1", "true", "yes"]
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
from huggingface_hub import logging as hf_logging
hf_logging.set_verbosity_error()

# --- Manager imports ---
from managers.vae_manager import vae_manager_singleton
from tools.video_encode_tool import video_encode_tool_singleton

# --- Global constants ---
LTXV_DEBUG = True
LTXV_FRAME_LOG_EVERY = 8
DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
RESULTS_DIR = Path("/app/output")
DEFAULT_FPS = 24.0

# ==============================================================================
# REPOSITORY SETUP AND IMPORTS
# ==============================================================================

def _run_setup_script():
    """Runs setup.py when the LTX-Video repository has not been cloned yet."""
    setup_script_path = "setup.py"
    if not os.path.exists(setup_script_path):
        print("[DEBUG] 'setup.py' not found. Skipping dependency cloning.")
        return

    print(f"[DEBUG] Repository not found at {LTX_VIDEO_REPO_DIR}. Running setup.py...")
    try:
        subprocess.run([sys.executable, setup_script_path], check=True, capture_output=True, text=True)
        print("[DEBUG] 'setup.py' finished successfully.")
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] 'setup.py' failed (exit code {e.returncode}).\nOutput:\n{e.stdout}\n{e.stderr}")
        sys.exit(1)

def add_deps_to_path(repo_path: Path):
    """Adds the repository directory to sys.path for local imports."""
    resolved_path = str(repo_path.resolve())
    if resolved_path not in sys.path:
        sys.path.insert(0, resolved_path)
    if LTXV_DEBUG:
        print(f"[DEBUG] Added to sys.path: {resolved_path}")

if not LTX_VIDEO_REPO_DIR.exists():
    _run_setup_script()
add_deps_to_path(LTX_VIDEO_REPO_DIR)

# --- Imports that depend on the path added above ---
from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
from ltx_video.pipelines.pipeline_ltx_video import adain_filter_latent
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
import ltx_video.pipelines.crf_compressor as crf_compressor

# ==============================================================================
# GPU POOL MANAGER - Multi-GPU system
# ==============================================================================

class GPUPoolType(Enum):
    """Available GPU pool types."""
    GENERATION = "generation"  # Pipeline + Upscaler
    DECODE = "decode"          # VAE Decode


@dataclass
class GPUTask:
    """A unit of work to be executed on a GPU."""
    task_id: str
    task_fn: Callable
    args: tuple
    kwargs: dict
    result_queue: queue.Queue


@dataclass
class GPUWorker:
    """A single GPU worker."""
    worker_id: int
    device_id: str
    pool_type: GPUPoolType
    thread: Optional[threading.Thread] = None
    is_busy: bool = False


class GPUPoolManager:
    """
    GPU pool manager that distributes tasks across devices.

    Architecture:
    - Pool 1 (GENERATION): 2 GPUs for pipeline + upscaler
    - Pool 2 (DECODE): 2 GPUs for VAE decode
    """

    def __init__(
        self,
        generation_devices: Optional[List[str]] = None,
        decode_devices: Optional[List[str]] = None,
        max_queue_size: int = 10
    ):
        """Initializes the pool manager and starts one worker thread per device."""
        self.generation_devices = generation_devices or ["cuda:0", "cuda:1"]
        self.decode_devices = decode_devices or ["cuda:2", "cuda:3"]

        self.generation_queue = queue.Queue(maxsize=max_queue_size)
        self.decode_queue = queue.Queue(maxsize=max_queue_size)

        self.generation_workers: List[GPUWorker] = []
        self.decode_workers: List[GPUWorker] = []

        self._shutdown = False
        self._lock = threading.Lock()

        self.stats = {
            "generation_tasks_completed": 0,
            "decode_tasks_completed": 0,
            "generation_tasks_failed": 0,
            "decode_tasks_failed": 0,
        }

        self._initialize_workers()

    def _initialize_workers(self):
        """Starts every GPU worker thread."""
        print("[GPU Pool Manager] Initializing workers...")

        for i, device in enumerate(self.generation_devices):
            worker = GPUWorker(
                worker_id=i,
                device_id=device,
                pool_type=GPUPoolType.GENERATION
            )
            worker.thread = threading.Thread(
                target=self._worker_loop,
                args=(worker, self.generation_queue),
                daemon=True
            )
            worker.thread.start()
            self.generation_workers.append(worker)
            print(f"  ✓ Generation worker {i} started on {device}")

        for i, device in enumerate(self.decode_devices):
            worker = GPUWorker(
                worker_id=i,
                device_id=device,
                pool_type=GPUPoolType.DECODE
            )
            worker.thread = threading.Thread(
                target=self._worker_loop,
                args=(worker, self.decode_queue),
                daemon=True
            )
            worker.thread.start()
            self.decode_workers.append(worker)
            print(f"  ✓ Decode worker {i} started on {device}")

        print(f"[GPU Pool Manager] {len(self.generation_workers)} GENERATION workers and {len(self.decode_workers)} DECODE workers active.\n")

    def _worker_loop(self, worker: GPUWorker, task_queue: queue.Queue):
        """Main loop of a worker thread: pull a task, run it, report the result."""
        print(f"[Worker {worker.worker_id}:{worker.device_id}] Waiting for tasks ({worker.pool_type.value})...")

        while not self._shutdown:
            try:
                # The 1-second timeout lets the loop re-check the shutdown flag
                # even when the queue is idle.
                task: GPUTask = task_queue.get(timeout=1.0)

                with self._lock:
                    worker.is_busy = True

                print(f"[Worker {worker.worker_id}:{worker.device_id}] Running task '{task.task_id}'...")

                try:
                    torch.cuda.set_device(worker.device_id)
                    # Every task function receives its assigned device as the
                    # first positional argument.
                    result = task.task_fn(
                        worker.device_id,
                        *task.args,
                        **task.kwargs
                    )
                    task.result_queue.put(("success", result))

                    with self._lock:
                        if worker.pool_type == GPUPoolType.GENERATION:
                            self.stats["generation_tasks_completed"] += 1
                        else:
                            self.stats["decode_tasks_completed"] += 1

                    print(f"[Worker {worker.worker_id}:{worker.device_id}] Task '{task.task_id}' completed successfully.")

                except Exception as e:
                    print(f"[Worker {worker.worker_id}:{worker.device_id}] ERROR in task '{task.task_id}': {e}")
                    traceback.print_exc()

                    task.result_queue.put(("error", str(e)))

                    with self._lock:
                        if worker.pool_type == GPUPoolType.GENERATION:
                            self.stats["generation_tasks_failed"] += 1
                        else:
                            self.stats["decode_tasks_failed"] += 1

                finally:
                    with self._lock:
                        worker.is_busy = False
                    task_queue.task_done()
                    torch.cuda.empty_cache()

            except queue.Empty:
                continue

    def submit_generation_task(
        self,
        task_id: str,
        task_fn: Callable,
        *args,
        **kwargs
    ) -> queue.Queue:
        """Submits a GENERATION task to the pool and returns its result queue."""
        result_queue = queue.Queue(maxsize=1)
        task = GPUTask(
            task_id=task_id,
            task_fn=task_fn,
            args=args,
            kwargs=kwargs,
            result_queue=result_queue
        )

        print(f"[GPU Pool Manager] Submitting GENERATION task: '{task_id}'")
        self.generation_queue.put(task)
        return result_queue

    def submit_decode_task(
        self,
        task_id: str,
        task_fn: Callable,
        *args,
        **kwargs
    ) -> queue.Queue:
        """Submits a DECODE task to the pool and returns its result queue."""
        result_queue = queue.Queue(maxsize=1)
        task = GPUTask(
            task_id=task_id,
            task_fn=task_fn,
            args=args,
            kwargs=kwargs,
            result_queue=result_queue
        )

        print(f"[GPU Pool Manager] Submitting DECODE task: '{task_id}'")
        self.decode_queue.put(task)
        return result_queue

    def get_result(self, result_queue: queue.Queue, timeout: Optional[float] = None):
        """Blocks until a task result is available and returns it."""
        status, result = result_queue.get(timeout=timeout)

        if status == "error":
            raise RuntimeError(f"Task failed: {result}")

        return result

    def submit_and_wait_generation(
        self,
        task_id: str,
        task_fn: Callable,
        *args,
        timeout: Optional[float] = None,
        **kwargs
    ):
        """Submits a generation task and waits for its result (blocking)."""
        result_queue = self.submit_generation_task(task_id, task_fn, *args, **kwargs)
        return self.get_result(result_queue, timeout=timeout)

    def submit_and_wait_decode(
        self,
        task_id: str,
        task_fn: Callable,
        *args,
        timeout: Optional[float] = None,
        **kwargs
    ):
        """Submits a decode task and waits for its result (blocking)."""
        result_queue = self.submit_decode_task(task_id, task_fn, *args, **kwargs)
        return self.get_result(result_queue, timeout=timeout)

    def wait_all(self):
        """Waits for every pending task to finish."""
        print("[GPU Pool Manager] Waiting for all tasks to finish...")
        self.generation_queue.join()
        self.decode_queue.join()
        print("[GPU Pool Manager] All tasks finished.")

    def get_stats(self) -> dict:
        """Returns pool usage statistics."""
        with self._lock:
            return {
                **self.stats,
                "generation_queue_size": self.generation_queue.qsize(),
                "decode_queue_size": self.decode_queue.qsize(),
                "generation_workers_busy": sum(1 for w in self.generation_workers if w.is_busy),
                "decode_workers_busy": sum(1 for w in self.decode_workers if w.is_busy),
            }

    def print_stats(self):
        """Prints formatted statistics."""
        stats = self.get_stats()
        print("\n" + "=" * 60)
        print("GPU POOL MANAGER - STATISTICS")
        print("=" * 60)
        print("Generation Pool:")
        print(f"  - Tasks completed: {stats['generation_tasks_completed']}")
        print(f"  - Tasks failed: {stats['generation_tasks_failed']}")
        print(f"  - Busy workers: {stats['generation_workers_busy']}/{len(self.generation_workers)}")
        print(f"  - Queue: {stats['generation_queue_size']} tasks")
        print("\nDecode Pool:")
        print(f"  - Tasks completed: {stats['decode_tasks_completed']}")
        print(f"  - Tasks failed: {stats['decode_tasks_failed']}")
        print(f"  - Busy workers: {stats['decode_workers_busy']}/{len(self.decode_workers)}")
        print(f"  - Queue: {stats['decode_queue_size']} tasks")
        print("=" * 60 + "\n")

    def shutdown(self):
        """Stops every worker thread."""
        print("[GPU Pool Manager] Shutting down...")
        self._shutdown = True

        for worker in self.generation_workers + self.decode_workers:
            if worker.thread:
                worker.thread.join(timeout=5.0)

        print("[GPU Pool Manager] Shut down.")


# Global singleton
_gpu_pool_manager_instance: Optional[GPUPoolManager] = None


def get_gpu_pool_manager(
    generation_devices: Optional[List[str]] = None,
    decode_devices: Optional[List[str]] = None,
    force_reinit: bool = False
) -> GPUPoolManager:
    """Returns the GPUPoolManager singleton, creating it on first use."""
    global _gpu_pool_manager_instance

    if _gpu_pool_manager_instance is None or force_reinit:
        if _gpu_pool_manager_instance and force_reinit:
            _gpu_pool_manager_instance.shutdown()

        _gpu_pool_manager_instance = GPUPoolManager(
            generation_devices=generation_devices,
            decode_devices=decode_devices
        )

    return _gpu_pool_manager_instance

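
# --- Usage sketch (illustrative only, not called anywhere) ---
# A minimal, hedged example of how the pool API above is meant to be driven.
# The `_demo_square` task below is hypothetical and exists only to show the
# calling convention: the pool passes the assigned device as the first
# positional argument of every task function, and results should come back
# on the CPU so they are safe to consume from the submitting thread.
def _example_pool_usage():
    def _demo_square(device_id: str, x: torch.Tensor) -> torch.Tensor:
        # Move the input to the worker's device, compute, return on CPU.
        return (x.to(device_id) ** 2).cpu()

    pool = get_gpu_pool_manager()
    # Blocking call: submits to the GENERATION queue and waits for the result.
    squared = pool.submit_and_wait_generation(
        "demo_square", _demo_square, torch.arange(4.0), timeout=30
    )
    pool.print_stats()
    return squared
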
# ==============================================================================
# PROCESSING HELPER FUNCTIONS
# ==============================================================================

def debug_log(message: str):
    """Conditional logging gated by LTXV_DEBUG."""
    if LTXV_DEBUG:
        print(f"[DEBUG] {message}")

def load_image_cv2(image_path: str, target_height: int, target_width: int) -> np.ndarray:
    """Loads an image with OpenCV, converts it to RGB and resizes it."""
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not load image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (target_width, target_height), interpolation=cv2.INTER_LINEAR)
    return image

def normalize_image(image: np.ndarray) -> np.ndarray:
    """Normalizes an image to [-1, 1]."""
    image = image.astype(np.float32) / 127.5 - 1.0
    return image

def denormalize_image(image: np.ndarray) -> np.ndarray:
    """Denormalizes an image from [-1, 1] back to [0, 255]."""
    image = (image + 1.0) * 127.5
    return np.clip(image, 0, 255).astype(np.uint8)

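# Worked example: normalize_image maps pixel 0 -> -1.0, 127.5 -> 0.0 and
# 255 -> 1.0; denormalize_image inverts the mapping, so
# denormalize_image(normalize_image(img)) recovers img for any uint8 image
# (up to rounding through the float32 round trip).
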
# ==============================================================================
# MAIN VIDEO SERVICE CLASS WITH GPU POOLS
# ==============================================================================

class VideoService:
    """
    Video generation service built on LTX Video and the Multi-GPU Pool Manager.

    GPU architecture:
    - GPUs 0 and 1: Pipeline + Upscaler (GENERATION pool)
    - GPUs 2 and 3: VAE Decode (DECODE pool)
    """

    def __init__(self):
        """Initializes the service and its GPU pools."""
        print("[VideoService] Initializing with Multi-GPU pools...")

        # Initialize the pool manager
        self.gpu_pool = get_gpu_pool_manager(
            generation_devices=["cuda:0", "cuda:1"],
            decode_devices=["cuda:2", "cuda:3"]
        )

        # Load configuration
        self.config = self._load_config("ltxv-13b-0.9.8-distilled-fp8.yaml")

        # Load the models once (a template that is cloned onto each GPU)
        self.pipeline_template, self.latent_upsampler_template = self._load_models_from_hub()

        # Clone the pipeline onto each generation GPU
        self.generation_models = {}
        for device in ["cuda:0", "cuda:1"]:
            self.generation_models[device] = self._clone_pipeline_to_device(device)

        # Clone the VAE onto each decode GPU
        self.decode_models = {}
        for device in ["cuda:2", "cuda:3"]:
            self.decode_models[device] = self._clone_vae_to_device(device)

        # Runtime settings
        self.runtime_autocast_dtype = self._get_precision_dtype()

        # Attach the template pipeline to the shared VAE manager
        vae_manager_singleton.attach_pipeline(
            self.pipeline_template,
            device="cuda:0",
            autocast_dtype=self.runtime_autocast_dtype
        )

        # Seed tracking and temporary storage
        self.used_seed = None
        self.tmp_dir = None
        self._register_tmp_dir()

        print("[VideoService] Initialized successfully!")
        print("[VideoService] Active GPU pools:")
        print("[VideoService]   - Generation: cuda:0, cuda:1")
        print("[VideoService]   - Decode: cuda:2, cuda:3")

    def _clone_pipeline_to_device(self, device: str) -> Dict:
        """Clones the generation models onto a specific device.

        nn.Module.to() moves modules in place, so the template is deep-copied
        first; otherwise every "clone" would share the same weights, all
        sitting on whichever device was moved to last. The deep copy is
        memory-hungry but gives each GPU an independent pipeline (this assumes
        LTXVideoPipeline supports the DiffusionPipeline-style .to()).
        """
        print(f"  Cloning pipeline to {device}...")
        models = {
            'pipeline': copy.deepcopy(self.pipeline_template).to(device),
        }

        if self.latent_upsampler_template:
            models['upsampler'] = copy.deepcopy(self.latent_upsampler_template).to(device)

        return models

    def _clone_vae_to_device(self, device: str) -> torch.nn.Module:
        """Clones the VAE onto a specific device (deep copy, see above)."""
        print(f"  Cloning VAE to {device}...")
        vae = copy.deepcopy(self.pipeline_template.vae).to(device)
        vae.eval()
        return vae

    # ==============================================================================
    # WORKER FUNCTIONS FOR THE POOL MANAGER
    # ==============================================================================

    def _generate_latents_worker(
        self,
        device_id: str,
        prompt: str,
        negative_prompt: str,
        height: int,
        width: int,
        num_frames: int,
        guidance_scale: float,
        seed: int,
        conditioning_items: Optional[List] = None
    ) -> torch.Tensor:
        """Latent generation worker (runs on cuda:0 or cuda:1)."""
        print(f"  [Generation Worker] Generating latents on {device_id}")

        generator = torch.Generator(device=device_id).manual_seed(seed)
        # Use the pipeline clone that lives on this worker's device.
        pipeline = self.generation_models[device_id]['pipeline']

        with torch.autocast(device_type='cuda', dtype=self.runtime_autocast_dtype):
            kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "height": height,
                "width": width,
                "num_frames": num_frames,
                "frame_rate": int(DEFAULT_FPS),
                "generator": generator,
                "output_type": "latent",
                "guidance_scale": float(guidance_scale),
                "conditioning_items": conditioning_items,
                **self.config.get("first_pass", {})
            }

            latents = pipeline(**kwargs).images

        # Apply the latent upsampler when one is available
        if 'upsampler' in self.generation_models[device_id]:
            latents = self._upsample_and_filter_latents(
                latents,
                self.generation_models[device_id]['upsampler'],
                device_id
            )

        return latents.cpu()

    def _refine_latents_worker(
        self,
        device_id: str,
        latents: torch.Tensor,
        prompt: str,
        negative_prompt: str,
        guidance_scale: float,
        seed: int,
        conditioning_items: Optional[List] = None
    ) -> torch.Tensor:
        """Latent refinement worker (runs on cuda:0 or cuda:1)."""
        print(f"  [Refine Worker] Refining latents on {device_id}")

        latents = latents.to(device_id)
        pipeline = self.generation_models[device_id]['pipeline']

        with torch.autocast(device_type='cuda', dtype=self.runtime_autocast_dtype):
            # The VAE downsamples spatially by 8x, so the pixel-space size is
            # recovered from the latent shape.
            refine_height = latents.shape[3] * 8  # vae_scale_factor
            refine_width = latents.shape[4] * 8

            kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "height": refine_height,
                "width": refine_width,
                "frame_rate": int(DEFAULT_FPS),
                "num_frames": latents.shape[2],
                "latents": latents,
                "guidance_scale": float(guidance_scale),
                "output_type": "latent",
                "generator": torch.Generator(device=device_id).manual_seed(seed),
                "conditioning_items": conditioning_items,
                **self.config.get("second_pass", {})
            }

            refined_latents = pipeline(**kwargs).images

        return refined_latents.cpu()

    def _decode_latents_worker(
        self,
        device_id: str,
        latents: torch.Tensor,
        decode_timestep: float = 0.05
    ) -> torch.Tensor:
        """Latent decoding worker (runs on cuda:2 or cuda:3).

        NOTE: decoding is routed through vae_manager_singleton (attached in
        __init__); the per-device VAE clone in self.decode_models is kept
        resident but is not passed to that call.
        """
        print(f"  [Decode Worker] Decoding on {device_id} (shape: {latents.shape})")

        latents = latents.to(device_id)
        vae = self.decode_models[device_id]  # per-device clone (unused by the singleton call below)

        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=self.runtime_autocast_dtype):
                pixel_tensor = vae_manager_singleton.decode(
                    latents,
                    decode_timestep=decode_timestep
                )

        return pixel_tensor.cpu()

    # ==============================================================================
    # DATA PREPARATION METHODS
    # ==============================================================================

    def _load_image_to_tensor_with_resize_and_crop(
        self,
        image_path: str,
        target_height: int,
        target_width: int,
        padding_values: tuple = (0, 0, 0)
    ) -> torch.Tensor:
        """Loads an image, resizes it and converts it to a (1, C, H, W) tensor."""
        image = load_image_cv2(image_path, target_height, target_width)
        image = normalize_image(image)
        tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()
        return tensor

    def _prepare_conditioning_tensor(
        self,
        image_path: str,
        target_height: int,
        target_width: int,
        padding_values: tuple = (0, 0, 0)
    ) -> torch.Tensor:
        """Prepares a conditioning tensor from an image file."""
        return self._load_image_to_tensor_with_resize_and_crop(
            image_path,
            target_height,
            target_width,
            padding_values
        )

    def _prepare_conditioning_tensor_from_path(self, image_path: str) -> torch.Tensor:
        """Prepares a conditioning tensor at the default 768x512 resolution."""
        return self._prepare_conditioning_tensor(image_path, 512, 768, (0, 0, 0))

    # ==============================================================================
    # COMPUTATION AND PROCESSING METHODS
    # ==============================================================================

    def _calculate_downscaled_dims(self, height: int, width: int) -> Tuple[int, int]:
        """Calculates the reduced dimensions used by the first pass."""
        downscale_factor = 4
        return height // downscale_factor, width // downscale_factor

    def _calculate_dynamic_cuts(
        self,
        total_latents: int,
        min_chunk_size: int = 8,
        overlap: int = 2
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        """Calculates dynamic cut points for overlapping chunks.

        Each chunk starts `overlap` frames before the previous chunk ends, so
        consecutive chunks share that many latent frames.
        """
        cut_points = []
        segment_sizes = []

        start = 0
        while start < total_latents:
            end = min(start + min_chunk_size, total_latents)
            cut_points.append((start, end))
            segment_sizes.append(end - start)

            if end >= total_latents:
                break

            start = end - overlap

        return cut_points, segment_sizes

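    # Worked example (illustrative): with total_latents=20, min_chunk_size=8
    # and overlap=2, _calculate_dynamic_cuts returns
    #   cut_points    = [(0, 8), (6, 14), (12, 20)]
    #   segment_sizes = [8, 8, 8]
    # i.e. each chunk re-reads the last 2 latent frames of its predecessor,
    # which the stitching helpers below trim out again.
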
    def _split_latents_with_overlap(
        self,
        latents: torch.Tensor,
        chunk_size: int = 8,
        overlap: int = 2
    ) -> List[torch.Tensor]:
        """Splits latents along the frame axis into overlapping chunks."""
        chunks = []
        start = 0
        total_frames = latents.shape[2]

        while start < total_frames:
            end = min(start + chunk_size, total_frames)
            chunk = latents[:, :, start:end, :, :]
            chunks.append(chunk)

            if end >= total_frames:
                break

            start = end - overlap

        return chunks

    def _merge_chunks_with_overlap(
        self,
        chunks: List[torch.Tensor],
        overlap: int = 2
    ) -> torch.Tensor:
        """Stitches chunks back together, dropping the overlapped frames.

        Consecutive chunks share `overlap` latent frames (8 pixel frames per
        latent frame), so the duplicated region is trimmed from the head of
        every chunk after the first; trimming both neighbors would also
        discard frames that only exist once.
        """
        if len(chunks) == 1:
            return chunks[0]

        overlap_pixels = overlap * 8  # 8 = VAE temporal scale factor

        result_parts = [chunks[0]]
        for chunk in chunks[1:]:
            result_parts.append(chunk[:, :, overlap_pixels:, :, :])

        return torch.cat(result_parts, dim=2)

    def _stitch_dynamic_chunks(
        self,
        pixel_chunks: List[torch.Tensor],
        segment_sizes: List[int],
        macro_overlap: int = 2
    ) -> torch.Tensor:
        """Stitches decoded pixel chunks, removing the overlapped frames.

        Each latent frame decodes to ~8 pixel frames, so `macro_overlap`
        latent frames correspond to `macro_overlap * 8` duplicated pixel
        frames at every junction. The duplicates are trimmed from the head of
        every chunk after the first. (`segment_sizes` is accepted for API
        symmetry but not needed by this trim strategy.)
        """
        if len(pixel_chunks) == 1:
            return pixel_chunks[0]

        overlap_frames = macro_overlap * 8
        stitched_parts = [pixel_chunks[0]]

        for chunk in pixel_chunks[1:]:
            stitched_parts.append(chunk[:, :, overlap_frames:, :, :])

        return torch.cat(stitched_parts, dim=2)

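    # Worked example (illustrative, assuming the ~8-pixel-frames-per-latent
    # decode): three chunks of 64 pixel frames each, with macro_overlap=2
    # (16 pixel frames duplicated at each junction), stitch to
    #   64 + (64 - 16) + (64 - 16) = 160 unique frames.
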
    def _upsample_and_filter_latents(
        self,
        latents: torch.Tensor,
        upsampler: torch.nn.Module,
        device: str
    ) -> torch.Tensor:
        """Applies the latent upsampler, then AdaIN-filters against the input."""
        latents = latents.to(device)

        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=self.runtime_autocast_dtype):
                upsampled = upsampler(latents)
                filtered = adain_filter_latent(upsampled, latents)

        return filtered

    # ==============================================================================
    # GENERATION AND REFINEMENT METHODS (VIA THE POOL MANAGER)
    # ==============================================================================

    def generate_low_resolution(
        self,
        prompt: str,
        negative_prompt: str,
        height: int,
        width: int,
        duration_secs: float,
        guidance_scale: float,
        seed: Optional[int] = None,
        image_filepaths: Optional[List[str]] = None
    ) -> Tuple[str, int]:
        """Generates a low-resolution video using the generation pool."""
        print("[INFO] Starting low-resolution generation...")

        used_seed = seed or random.randint(0, 2**32 - 1)
        self._seed_everything(used_seed)

        actual_num_frames = int(round(duration_secs * DEFAULT_FPS))
        downscaled_height, downscaled_width = self._calculate_downscaled_dims(height, width)

        conditioning_items = []
        if image_filepaths:
            for filepath in image_filepaths:
                cond_tensor = self._prepare_conditioning_tensor(
                    filepath,
                    downscaled_height,
                    downscaled_width,
                    (0, 0, 0)
                )
                conditioning_items.append(ConditioningItem(cond_tensor, 0, 1.0))

        # Submit the generation task to the pool. num_frames is expressed in
        # latent frames: one latent frame per 8 pixel frames, plus one.
        task_id = f"gen_lowres_{used_seed}"
        latents = self.gpu_pool.submit_and_wait_generation(
            task_id=task_id,
            task_fn=self._generate_latents_worker,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=downscaled_height,
            width=downscaled_width,
            num_frames=(actual_num_frames // 8) + 1,
            guidance_scale=guidance_scale,
            seed=used_seed,
            conditioning_items=conditioning_items or None,
            timeout=600
        )

        tensor_path = self._save_latents_to_disk(latents, "latents_low_res", used_seed)

        print("[SUCCESS] Low-resolution generation finished!")
        self.used_seed = used_seed

        return tensor_path, used_seed

    def refine_texture_only(
        self,
        latents_path: str,
        prompt: str,
        negative_prompt: str,
        guidance_scale: float,
        seed: int,
        image_filepaths: Optional[List[str]] = None,
        macro_chunk_size: int = 8,
        macro_overlap: int = 2
    ) -> Tuple[str, str, torch.Tensor]:
        """Refines and decodes latents, alternating between the two pools chunk by chunk."""
        print("[INFO] Starting chunked refinement and decoding...")

        initial_latents = torch.load(latents_path).cpu()
        total_latents = initial_latents.shape[2]
        height = initial_latents.shape[3] * 8
        width = initial_latents.shape[4] * 8

        cut_points, segment_sizes = self._calculate_dynamic_cuts(
            total_latents,
            min_chunk_size=macro_chunk_size,
            overlap=macro_overlap
        )

        print(f"  Processing {len(cut_points)} chunks...")

        # Prepare conditioning when images were provided
        conditioning_items = []
        if image_filepaths:
            for filepath in image_filepaths:
                cond_tensor = self._prepare_conditioning_tensor(
                    filepath,
                    height,
                    width,
                    (0, 0, 0)
                )
                conditioning_items.append(ConditioningItem(cond_tensor, 0, 1.0))

        pixel_results = []

        for i, (start, end) in enumerate(cut_points):
            chunk_id = f"chunk_{i}_seed_{seed}"
            latent_chunk = initial_latents[:, :, start:end, :, :]

            # STEP 1: refine the latents (generation pool)
            print(f"\n  [{i+1}/{len(cut_points)}] Refining chunk {start}-{end}...")
            refined_latents = self.gpu_pool.submit_and_wait_generation(
                task_id=f"refine_{chunk_id}",
                task_fn=self._refine_latents_worker,
                latents=latent_chunk,
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                seed=seed + i,
                conditioning_items=conditioning_items or None,
                timeout=600
            )

            # STEP 2: decode the latents (decode pool)
            print(f"  [{i+1}/{len(cut_points)}] Decoding chunk {start}-{end}...")
            pixel_tensor = self.gpu_pool.submit_and_wait_decode(
                task_id=f"decode_{chunk_id}",
                task_fn=self._decode_latents_worker,
                latents=refined_latents,
                decode_timestep=float(self.config.get("decode_timestep", 0.05)),
                timeout=300
            )

            pixel_results.append(pixel_tensor)

            del refined_latents
            torch.cuda.empty_cache()

        # Stitch the results
        print("\n  Stitching final chunks...")
        final_pixel_tensor = self._stitch_dynamic_chunks(
            pixel_results,
            segment_sizes,
            macro_overlap
        )

        final_video_path = self._save_video_from_tensor(
            final_pixel_tensor,
            "final_video",
            seed
        )

        print(f"[SUCCESS] Final video saved at: {final_video_path}")
        self.gpu_pool.print_stats()

        return final_video_path, latents_path, final_pixel_tensor

    def apply_secondary_refinement(
        self,
        initial_latents_path: str,
        prompt: str,
        negative_prompt: str,
        guidance_scale: float,
        seed: int,
        image_filepaths: Optional[List[str]] = None
    ) -> str:
        """Applies a secondary refinement pass over multiple chunks.

        Unlike refine_texture_only, all refinement tasks are submitted up
        front, so the generation and decode pools can overlap their work.
        """
        print("[INFO] Applying secondary refinement...")

        initial_latents = torch.load(initial_latents_path).cpu()
        total_latents = initial_latents.shape[2]

        # Split into larger chunks
        macro_chunk_size = 16
        macro_overlap = 2

        cut_points, segment_sizes = self._calculate_dynamic_cuts(
            total_latents,
            min_chunk_size=macro_chunk_size,
            overlap=macro_overlap
        )

        height = initial_latents.shape[3] * 8
        width = initial_latents.shape[4] * 8

        conditioning_items = []
        if image_filepaths:
            for filepath in image_filepaths:
                cond_tensor = self._prepare_conditioning_tensor(
                    filepath, height, width, (0, 0, 0)
                )
                conditioning_items.append(ConditioningItem(cond_tensor, 0, 1.0))

        print(f"  Refining {len(cut_points)} chunks...")

        # Submit ALL refinement tasks first. The local name `refine_queue`
        # avoids shadowing the `queue` module imported at the top of the file.
        refine_queues = []
        for i, (start, end) in enumerate(cut_points):
            latent_chunk = initial_latents[:, :, start:end, :, :]

            refine_queue = self.gpu_pool.submit_generation_task(
                task_id=f"refine_macro_{i}",
                task_fn=self._refine_latents_worker,
                latents=latent_chunk,
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                seed=seed + i,
                conditioning_items=conditioning_items or None
            )
            refine_queues.append((i, refine_queue))

        # Kick off decodes as refinements become ready
        print("\n  Decoding refined chunks...")
        decode_queues = []

        for i, refine_queue in refine_queues:
            refined_latents = self.gpu_pool.get_result(refine_queue, timeout=600)
            print(f"  ✓ Chunk {i} refined")

            decode_queue = self.gpu_pool.submit_decode_task(
                task_id=f"decode_macro_{i}",
                task_fn=self._decode_latents_worker,
                latents=refined_latents,
                decode_timestep=float(self.config.get("decode_timestep", 0.05))
            )
            decode_queues.append((i, decode_queue))

        # Wait for every decode
        print("\n  Waiting for all decodes to finish...")
        pixel_results = []

        for i, decode_queue in decode_queues:
            pixel_tensor = self.gpu_pool.get_result(decode_queue, timeout=300)
            pixel_results.append(pixel_tensor)
            print(f"  ✓ Chunk {i} decoded")

        # Stitch the final result
        print("\n  Stitching final result...")
        final_pixel_tensor = self._stitch_dynamic_chunks(
            pixel_results,
            segment_sizes,
            macro_overlap
        )

        final_video_path = self._save_video_from_tensor(
            final_pixel_tensor,
            "refined_final_video",
            seed
        )

        print(f"[SUCCESS] Refined video saved at: {final_video_path}")
        self.gpu_pool.print_stats()

        return final_video_path

    def encode_latents_to_mp4(
        self,
        pixel_tensor: torch.Tensor,
        output_path: str,
        fps: float = 24.0
    ) -> str:
        """Encodes a pixel tensor into an MP4 file."""
        print(f"[INFO] Encoding video to MP4: {output_path}")

        # Denormalize from [-1, 1] to [0, 255]
        pixel_tensor = (pixel_tensor + 1.0) / 2.0 * 255.0
        pixel_tensor = torch.clamp(pixel_tensor, 0, 255)

        # Encode to a video file
        video_encode_tool_singleton.encode_video_from_tensor(
            pixel_tensor,
            output_path,
            fps=fps
        )

        print(f"[SUCCESS] Video encoded: {output_path}")
        return output_path

    # ==============================================================================
    # CONFIGURATION AND LOADING METHODS
    # ==============================================================================

    def _load_config(self, config_file: str) -> Dict:
        """Loads the YAML configuration."""
        config_path = LTX_VIDEO_REPO_DIR / "configs" / config_file

        if not config_path.exists():
            print(f"[WARNING] Config file not found: {config_path}")
            return {}

        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        return config or {}

    def _load_models_from_hub(self) -> Tuple[LTXVideoPipeline, Optional[LatentUpsampler]]:
        """Loads the models from the Hugging Face Hub."""
        print("[INFO] Loading models from the Hub...")

        # Load the pipeline
        pipeline = LTXVideoPipeline.from_pretrained(
            "Lightricks/LTX-Video",
            torch_dtype=torch.bfloat16
        )

        # Load the upsampler (optional)
        try:
            upsampler = LatentUpsampler.from_pretrained(
                "Lightricks/LTX-Video",
                torch_dtype=torch.bfloat16
            )
        except Exception as e:
            print(f"[WARNING] Upsampler not available: {e}")
            upsampler = None

        print("[SUCCESS] Models loaded successfully!")
        return pipeline, upsampler

    def _move_models_to_device(self):
        """Moves models to a single main device (unused with pools)."""
        # Handled by _clone_pipeline_to_device instead
        pass

    def _get_precision_dtype(self) -> torch.dtype:
        """Picks the autocast precision based on the available hardware."""
        if torch.cuda.is_available():
            device_props = torch.cuda.get_device_properties(0)
            if device_props.major >= 8:  # compute capability 8.0+: A100, H100, etc.
                return torch.bfloat16

        return torch.float16

    # ==============================================================================
    # SAVING AND HOUSEKEEPING HELPERS
    # ==============================================================================

    def _save_latents_to_disk(
        self,
        latents: torch.Tensor,
        prefix: str,
        seed: int
    ) -> str:
        """Saves latents to a .pt file."""
        filename = f"{prefix}_{seed}.pt"
        filepath = self.tmp_dir / filename

        torch.save(latents, filepath)
        print(f"  Latents saved: {filepath}")

        return str(filepath)

    def _save_video_from_tensor(
        self,
        pixel_tensor: torch.Tensor,
        prefix: str,
        seed: int
    ) -> str:
        """Saves a pixel tensor as an MP4 video."""
        filename = f"{prefix}_{seed}.mp4"
        filepath = RESULTS_DIR / filename

        RESULTS_DIR.mkdir(parents=True, exist_ok=True)

        self.encode_latents_to_mp4(pixel_tensor, str(filepath), fps=DEFAULT_FPS)

        print(f"  Video saved: {filepath}")
        return str(filepath)

    def _finalize(self):
        """Shuts the service down and releases resources."""
        print("[INFO] Shutting down VideoService...")

        self.gpu_pool.print_stats()
        self.gpu_pool.shutdown()

        if self.tmp_dir and self.tmp_dir.exists():
            shutil.rmtree(self.tmp_dir)
            print(f"  Temporary directory removed: {self.tmp_dir}")

        # Free CUDA memory
        torch.cuda.empty_cache()
        gc.collect()

        print("[SUCCESS] VideoService shut down!")

    def _seed_everything(self, seed: int):
        """Seeds all RNGs for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def _register_tmp_dir(self):
        """Registers a temporary directory for saving latents."""
        self.tmp_dir = Path(tempfile.mkdtemp(prefix="ltx_video_"))
        print(f"  Temporary directory: {self.tmp_dir}")

# ==============================================================================
# ENTRY POINT AND USAGE EXAMPLE
# ==============================================================================

if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("LTX VIDEO SERVICE - Multi-GPU Pool Manager")
    print("=" * 80 + "\n")

    try:
        # Initialize the service
        print("Creating the VideoService instance...")
        video_service = VideoService()

        # Example 1: low-resolution generation
        print("\n[EXAMPLE 1] Low-resolution generation...")
        latents_path, seed = video_service.generate_low_resolution(
            prompt="A beautiful sunset over the ocean",
            negative_prompt="blurry, low quality",
            height=512,
            width=768,
            duration_secs=2.0,
            guidance_scale=3.0,
            seed=42,
            image_filepaths=None
        )

        # Example 2: refinement and decoding
        print("\n[EXAMPLE 2] Refinement and decoding...")
        video_path, latents_path, final_tensor = video_service.refine_texture_only(
            latents_path=latents_path,
            prompt="A beautiful sunset over the ocean",
            negative_prompt="blurry, low quality",
            guidance_scale=3.0,
            seed=seed,
            image_filepaths=None,
            macro_chunk_size=8,
            macro_overlap=2
        )

        print(f"\n✓ Final video generated: {video_path}")

    except KeyboardInterrupt:
        print("\n\n[INFO] Interrupted by the user.")
    except Exception as e:
        print(f"\n\n[ERROR] Execution failed: {e}")
        traceback.print_exc()
    finally:
        if 'video_service' in locals():
            video_service._finalize()

    print("\n" + "=" * 80)
    print("Execution finished")
    print("=" * 80 + "\n")