caarleexx commited on
Commit
9094e95
·
verified ·
1 Parent(s): 232cc6e

Delete api/ltx_server_refactored.py

Browse files
Files changed (1) hide show
  1. api/ltx_server_refactored.py +0 -1388
api/ltx_server_refactored.py DELETED
@@ -1,1388 +0,0 @@
1
- # ltx_server_clean_refactor.py — VideoService (Modular Version with Simple Overlap Chunking)
2
-
3
- # ==============================================================================
4
- # 0. CONFIGURAÇÃO DE AMBIENTE E IMPORTAÇÕES
5
- # ==============================================================================
6
- import os
7
- import sys
8
- import gc
9
- import yaml
10
- import time
11
- import json
12
- import copy
13
- import random
14
- import shutil
15
- import warnings
16
- import tempfile
17
- import traceback
18
- import subprocess
19
- from pathlib import Path
20
- from typing import List, Dict, Optional, Tuple, Union
21
- import cv2
22
-
23
- ENABLE_MEMORY_OPTIMIZATION = os.getenv("ADUC_MEMORY_OPTIMIZATION", "1").lower() in ["1", "true", "yes"]
24
-
25
-
26
- # --- Configurações de Logging e Avisos ---
27
- warnings.filterwarnings("ignore", category=UserWarning)
28
- warnings.filterwarnings("ignore", category=FutureWarning)
29
- from huggingface_hub import logging as hf_logging
30
- hf_logging.set_verbosity_error()
31
-
32
- # --- Importações de Bibliotecas de ML/Processamento ---
33
- import torch
34
- import torch.nn.functional as F
35
- import numpy as np
36
- from PIL import Image
37
- from einops import rearrange
38
- from huggingface_hub import hf_hub_download
39
- from safetensors import safe_open
40
-
41
- from managers.vae_manager import vae_manager_singleton
42
- from tools.video_encode_tool import video_encode_tool_singleton
43
-
44
- # --- Constantes Globais ---
45
- LTXV_DEBUG = True # Mude para False para desativar logs de debug
46
- LTXV_FRAME_LOG_EVERY = 8
47
- DEPS_DIR = Path("/data")
48
- LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
49
- RESULTS_DIR = Path("/app/output")
50
- DEFAULT_FPS = 24.0
51
-
52
- def add_deps_to_path(repo_path: Path):
53
- """Adiciona o diretório do repositório ao sys.path para importações locais."""
54
- resolved_path = str(repo_path.resolve())
55
- if resolved_path not in sys.path:
56
- sys.path.insert(0, resolved_path)
57
-
58
- add_deps_to_path(LTX_VIDEO_REPO_DIR)
59
-
60
- # --- Importações Dependentes do Path Adicionado ---
61
- from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
62
- from ltx_video.pipelines.pipeline_ltx_video import adain_filter_latent
63
- from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
64
- from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
65
- from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
66
- from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
67
- from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
68
- from ltx_video.models.transformers.transformer3d import Transformer3DModel
69
- from ltx_video.schedulers.rf import RectifiedFlowScheduler
70
- from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
71
- import ltx_video.pipelines.crf_compressor as crf_compressor
72
-
73
- def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
74
- latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
75
- latent_upsampler.to(device)
76
- latent_upsampler.eval()
77
- return latent_upsampler
78
-
79
- def create_ltx_video_pipeline(
80
- ckpt_path: str,
81
- precision: str,
82
- text_encoder_model_name_or_path: str,
83
- sampler: Optional[str] = None,
84
- device: Optional[str] = None,
85
- enhance_prompt: bool = False,
86
- prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
87
- prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
88
- ) -> LTXVideoPipeline:
89
- ckpt_path = Path(ckpt_path)
90
- assert os.path.exists(
91
- ckpt_path
92
- ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
93
-
94
- with safe_open(ckpt_path, framework="pt") as f:
95
- metadata = f.metadata()
96
- config_str = metadata.get("config")
97
- configs = json.loads(config_str)
98
- allowed_inference_steps = configs.get("allowed_inference_steps", None)
99
-
100
- vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
101
- transformer = Transformer3DModel.from_pretrained(ckpt_path)
102
-
103
- # Use constructor if sampler is specified, otherwise use from_pretrained
104
- if sampler == "from_checkpoint" or not sampler:
105
- scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
106
- else:
107
- scheduler = RectifiedFlowScheduler(
108
- sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
109
- )
110
-
111
- text_encoder = T5EncoderModel.from_pretrained(
112
- text_encoder_model_name_or_path, subfolder="text_encoder"
113
- )
114
- patchifier = SymmetricPatchifier(patch_size=1)
115
- tokenizer = T5Tokenizer.from_pretrained(
116
- text_encoder_model_name_or_path, subfolder="tokenizer"
117
- )
118
-
119
- transformer = transformer.to(device)
120
- vae = vae.to(device)
121
- text_encoder = text_encoder.to(device)
122
-
123
- if enhance_prompt:
124
- prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
125
- prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
126
- )
127
- prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
128
- prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
129
- )
130
- prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
131
- prompt_enhancer_llm_model_name_or_path,
132
- torch_dtype="bfloat16",
133
- )
134
- prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
135
- prompt_enhancer_llm_model_name_or_path,
136
- )
137
- else:
138
- prompt_enhancer_image_caption_model = None
139
- prompt_enhancer_image_caption_processor = None
140
- prompt_enhancer_llm_model = None
141
- prompt_enhancer_llm_tokenizer = None
142
-
143
- vae = vae.to(torch.bfloat16)
144
- if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
145
- transformer = transformer.to(torch.bfloat16)
146
- text_encoder = text_encoder.to(torch.bfloat16)
147
-
148
- # Use submodels for the pipeline
149
- submodel_dict = {
150
- "transformer": transformer,
151
- "patchifier": patchifier,
152
- "text_encoder": text_encoder,
153
- "tokenizer": tokenizer,
154
- "scheduler": scheduler,
155
- "vae": vae,
156
- "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
157
- "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
158
- "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
159
- "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
160
- "allowed_inference_steps": allowed_inference_steps,
161
- }
162
-
163
- pipeline = LTXVideoPipeline(**submodel_dict)
164
-
165
- pipeline = pipeline.to(device)
166
- return pipeline
167
-
168
- # ==============================================================================
169
- # 2. FUNÇÕES AUXILIARES DE PROCESSAMENTO
170
- # ==============================================================================
171
-
172
- def calculate_padding(orig_h: int, orig_w: int, target_h: int, target_w: int) -> Tuple[int, int, int, int]:
173
- """Calcula o preenchimento para centralizar uma imagem em uma nova dimensão."""
174
- pad_h = target_h - orig_h
175
- pad_w = target_w - orig_w
176
- pad_top = pad_h // 2
177
- pad_bottom = pad_h - pad_top
178
- pad_left = pad_w // 2
179
- pad_right = pad_w - pad_left
180
- return (pad_left, pad_right, pad_top, pad_bottom)
181
-
182
- def log_tensor_info(tensor: torch.Tensor, name: str = "Tensor"):
183
- """Exibe informações detalhadas sobre um tensor para depuração."""
184
- if not isinstance(tensor, torch.Tensor):
185
- print(f"\n[INFO] '{name}' não é um tensor.")
186
- return
187
- print(f"\n--- Tensor Info: {name} ---")
188
- print(f" - Shape: {tuple(tensor.shape)}")
189
- print(f" - Dtype: {tensor.dtype}")
190
- print(f" - Device: {tensor.device}")
191
- if tensor.numel() > 0:
192
- try:
193
- print(f" - Stats: Min={tensor.min().item():.4f}, Max={tensor.max().item():.4f}, Mean={tensor.mean().item():.4f}")
194
- except RuntimeError:
195
- print(" - Stats: Não foi possível calcular (ex: tensores bool).")
196
- print("-" * 30)
197
-
198
- # ==============================================================================
199
- # 3. CLASSE PRINCIPAL DO SERVIÇO DE VÍDEO
200
- # ==============================================================================
201
-
202
-
203
-
204
-
205
-
206
-
207
-
208
- # Nova configuração para 4 GPUs
209
- GPU_CONFIG = {
210
- "transformer_workers": [0, 1], # GPUs para transformer + text_encoder
211
- "vae_workers": [2, 3], # GPUs para VAE + upscaler
212
- "enable_multi_gpu": True
213
- }
214
-
215
-
216
-
217
-
218
-
219
- # Adicione/modifique estas configurações no início do arquivo
220
- PRECISION_CONFIG = {
221
- "enable_fp8": False, # Desabilitar FP8 devido a problemas de compatibilidade
222
- "default_dtype": torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
223
- "fallback_dtype": torch.float16
224
- }
225
-
226
- # Modifique a classe VideoService
227
- class VideoService:
228
-
229
-
230
- def _get_safe_precision_dtype(self):
231
- """Configuração de precisão mais segura para evitar conflitos de dtype."""
232
- if not torch.cuda.is_available():
233
- return torch.float32
234
-
235
- # Verificar suporte a bfloat16
236
- if torch.cuda.is_bf16_supported():
237
- print("[INFO] Usando bfloat16 (suportado pela GPU)")
238
- return torch.bfloat16
239
- else:
240
- print("[INFO] Usando float16 (bfloat16 não suportado)")
241
- return torch.float16
242
-
243
- def _load_models_with_safe_dtype(self):
244
- """Carrega modelos com dtype seguro e verifica compatibilidade."""
245
- print("[INFO] Carregando modelos com dtype seguro...")
246
-
247
- # CORREÇÃO: Usar dtype seguro explicitamente
248
- torch_dtype = self.runtime_autocast_dtype
249
-
250
- try:
251
- pipeline = LTXVideoPipeline.from_pretrained(
252
- "Lightricks/LTX-Video",
253
- torch_dtype=torch_dtype,
254
- variant="fp8" if PRECISION_CONFIG["enable_fp8"] else None,
255
- cache_dir=MODEL_CACHE_DIR
256
- )
257
- except Exception as e:
258
- print(f"[WARNING] Erro ao carregar com {torch_dtype}: {e}")
259
- print("[INFO] Tentando carregar com float16...")
260
- torch_dtype = torch.float16
261
- pipeline = LTXVideoPipeline.from_pretrained(
262
- "Lightricks/LTX-Video",
263
- torch_dtype=torch_dtype,
264
- cache_dir=MODEL_CACHE_DIR
265
- )
266
-
267
- # CORREÇÃO: Verificar e ajustar dtypes dos componentes do modelo
268
- self._ensure_consistent_dtypes(pipeline, torch_dtype)
269
-
270
- # Carregar upscaler com o mesmo dtype
271
- latent_upsampler = self._load_latent_upsampler(torch_dtype)
272
-
273
- return pipeline, latent_upsampler
274
-
275
- def _ensure_consistent_dtypes(self, pipeline, expected_dtype):
276
- """Garante que todos os componentes do pipeline tenham dtypes consistentes."""
277
- print("[INFO] Verificando consistência de dtypes...")
278
-
279
- components = [
280
- (pipeline.transformer, "transformer"),
281
- (pipeline.text_encoder, "text_encoder"),
282
- (pipeline.vae, "vae"),
283
- (pipeline.patchifier, "patchifier")
284
- ]
285
-
286
- for component, name in components:
287
- if hasattr(component, 'parameters') and next(component.parameters(), None) is not None:
288
- actual_dtype = next(component.parameters()).dtype
289
- if actual_dtype != expected_dtype:
290
- print(f"[INFO] Convertendo {name} de {actual_dtype} para {expected_dtype}")
291
- component.to(dtype=expected_dtype)
292
-
293
- print("[INFO] Verificação de dtypes concluída.")
294
-
295
- def _load_latent_upsampler(self, torch_dtype):
296
- """Carrega o latent upscaler com dtype seguro."""
297
- try:
298
- from ltx_video.models import LatentUpscaler
299
- upscaler = LatentUpscaler.from_pretrained(
300
- "Lightricks/LTX-Video",
301
- subfolder="ltxv-spatial-upscaler-0.9.8",
302
- torch_dtype=torch_dtype,
303
- cache_dir=MODEL_CACHE_DIR
304
- )
305
- return upscaler
306
- except Exception as e:
307
- print(f"[WARNING] Não foi possível carregar o latent upscaler: {e}")
308
- return None
309
-
310
- def _setup_4gpu_workers(self):
311
- """Configura 4 workers com verificação de dtype."""
312
- if self.multi_gpu_enabled:
313
- print("[INFO] Distribuindo modelos em 4 workers...")
314
-
315
- # Workers 0 e 1: Transformer + Text Encoder
316
- for i, device in enumerate(self.transformer_devices):
317
- print(f"[INFO] Worker {i} (Transformer): {device}")
318
- if i == 0:
319
- self.pipeline.transformer.to(device)
320
- self.pipeline.text_encoder.to(device)
321
- #self.pipeline.patchifier.to(device)
322
- # Nota: Para multi-worker transformer, precisaríamos de cópias do modelo
323
-
324
- # Workers 2 e 3: VAE
325
- for i, device in enumerate(self.vae_devices):
326
- print(f"[INFO] Worker {i+2} (VAE): {device}")
327
- if i == 0:
328
- self.pipeline.vae.to(device)
329
-
330
- # Upscaler
331
- if self.latent_upsampler:
332
- self.latent_upsampler.to(self.vae_devices[0])
333
-
334
- print("[INFO] Distribuição 4-Workers concluída.")
335
-
336
- # CORREÇÃO: Verificar dtypes após mover para GPU
337
- self._verify_gpu_dtypes()
338
- else:
339
- self.pipeline.to(self.device_ltx)
340
- if self.latent_upsampler:
341
- self.latent_upsampler.to(self.device_ltx)
342
-
343
- def _verify_gpu_dtypes(self):
344
- """Verifica se os dtypes estão consistentes após mover para GPU."""
345
- print("[INFO] Verificando dtypes nas GPUs...")
346
-
347
- components = [
348
- (self.pipeline.transformer, "transformer"),
349
- (self.pipeline.vae, "vae")
350
- ]
351
-
352
- for component, name in components:
353
- if hasattr(component, 'parameters') and next(component.parameters(), None) is not None:
354
- param = next(component.parameters())
355
- print(f" {name}: dtype={param.dtype}, device={param.device}")
356
-
357
- print("[INFO] Verificação de GPU dtypes concluída.")
358
-
359
- def generate_low_resolution1(self, prompt: str, negative_prompt: str,
360
- height: int, width: int, num_frames: int,
361
- guidance_scale: float, seed: Optional[int] = None,
362
- conditioning_items: Optional[List[ConditioningItem]] = None) -> Tuple[str, str, int]:
363
- """Geração de baixa resolução com dtype seguro."""
364
- print("\n[INFO] Iniciando ETAPA 1: Geração de Baixa Resolução...")
365
- self._set_generation_environment()
366
-
367
- temp_dir = tempfile.mkdtemp(prefix="ltxv_low_")
368
- self._register_tmp_dir(temp_dir)
369
- used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
370
-
371
- # Determinar dispositivo
372
- if self.multi_gpu_enabled:
373
- device = self.transformer_devices[0]
374
- else:
375
- device = self.device_ltx
376
-
377
- print(f" - Usando Seed: {used_seed}")
378
- print(f" - Frames: {num_frames}, Duração: {num_frames/DEFAULT_FPS:.1f}s")
379
- print(f" - Dimensões de Saída: {height}x{width}")
380
- print(f" - Dispositivo: {device}, Dtype: {self.runtime_autocast_dtype}")
381
-
382
- # CORREÇÃO: Configuração de autocast mais robusta
383
- device_type = device.split(':')[0] if ':' in device else device
384
- enabled_autocast = 'cuda' in device and self.runtime_autocast_dtype in [torch.float16, torch.bfloat16]
385
-
386
- print(f" - Autocast habilitado: {enabled_autocast}")
387
-
388
- try:
389
- with torch.autocast(device_type=device_type,
390
- dtype=self.runtime_autocast_dtype,
391
- enabled=enabled_autocast):
392
-
393
- first_pass_kwargs = {
394
- "prompt": prompt, "negative_prompt": negative_prompt,
395
- "height": height, "width": width,
396
- "frame_rate": int(DEFAULT_FPS), "num_frames": num_frames,
397
- "guidance_scale": float(guidance_scale),
398
- "output_type": "latent",
399
- "generator": torch.Generator(device=device).manual_seed(used_seed),
400
- "conditioning_items": conditioning_items,
401
- **(self.config.get("first_pass", {}))
402
- }
403
-
404
- print(" - Enviando para a pipeline LTX...")
405
- latents = self.pipeline(**first_pass_kwargs).images
406
- print(f" [LOG] Latentes gerados. Shape: {latents.shape}, Dtype: {latents.dtype}")
407
-
408
- except RuntimeError as e:
409
- print(f"[ERROR] Erro durante a geração: {e}")
410
- print("[INFO] Tentando fallback para float32...")
411
- # Fallback para float32
412
- with torch.autocast(device_type=device_type, dtype=torch.float32, enabled=False):
413
- first_pass_kwargs = {
414
- "prompt": prompt, "negative_prompt": negative_prompt,
415
- "height": height, "width": width,
416
- "frame_rate": int(DEFAULT_FPS), "num_frames": num_frames,
417
- "guidance_scale": float(guidance_scale),
418
- "output_type": "latent",
419
- "generator": torch.Generator(device=device).manual_seed(used_seed),
420
- "conditioning_items": conditioning_items,
421
- **(self.config.get("first_pass", {}))
422
- }
423
- latents = self.pipeline(**first_pass_kwargs).images
424
-
425
- # Resto do método permanece igual...
426
- latents_cpu = latents.cpu()
427
- del latents
428
- torch.cuda.empty_cache()
429
-
430
- # ... (decodificação e salvamento)
431
- latents_path = self._save_latents_to_disk(latents_cpu, "latents_low", used_seed)
432
-
433
- print("\n[INFO] Decodificando vídeo de baixa resolução...")
434
- self._set_decode_environment()
435
-
436
- # Decodificação (similar ao código anterior)
437
- total_latents = latents_cpu.shape[2]
438
- pontos_de_corte, segment_sizes = self._calculate_dynamic_cuts(total_latents)
439
-
440
- if len(pontos_de_corte) == 1:
441
- vae_device = self.vae_devices[0] if self.multi_gpu_enabled else self.device_vae
442
- latents_for_decode = latents_cpu.to(vae_device)
443
- vae_manager = self._get_vae_manager(vae_device)
444
-
445
- pixel_tensor = vae_manager.decode(
446
- latents_for_decode,
447
- decode_timestep=float(self.config.get("decode_timestep", 0.05))
448
- ).cpu()
449
- else:
450
- print(f" [LOG] Decodificação em {len(pontos_de_corte)} chunks...")
451
- pixel_chunks_list = []
452
- for i, (start, end) in enumerate(pontos_de_corte):
453
- start, end = max(0, start), min(total_latents, end)
454
- if start >= end:
455
- continue
456
-
457
- latent_chunk = latents_cpu[:, :, start:end, :, :]
458
- vae_device = self.vae_devices[0] if self.multi_gpu_enabled else self.device_vae
459
- latent_chunk = latent_chunk.to(vae_device)
460
- vae_manager = self._get_vae_manager(vae_device)
461
-
462
- print(f" -> Decodificando Grupo {i+1} (latentes {start} a {end-1})")
463
-
464
- pixel_chunk = vae_manager.decode(
465
- latent_chunk,
466
- decode_timestep=float(self.config.get("decode_timestep", 0.05))
467
- )
468
- pixel_chunks_list.append(pixel_chunk.cpu())
469
- torch.cuda.empty_cache()
470
-
471
- pixel_tensor = self._stitch_dynamic_chunks(pixel_chunks_list, segment_sizes)
472
-
473
- video_path = self._save_video_from_tensor(pixel_tensor, "video_low", used_seed, temp_dir)
474
- self._set_generation_environment()
475
-
476
- del latents_cpu
477
- self._finalize()
478
-
479
- print("\n[SUCCESS] Geração de Baixa Resolução Concluída")
480
- return video_path, latents_path, used_seed
481
-
482
-
483
- def __init__(self):
484
- """Inicializa o serviço com 4 workers especializados."""
485
- t0 = time.perf_counter()
486
- print("[INFO] Inicializando VideoService com 4 Workers...")
487
-
488
- # Configuração para 4 GPUs
489
- self.multi_gpu_enabled = GPU_CONFIG["enable_multi_gpu"] and torch.cuda.device_count() >= 4
490
-
491
- if self.multi_gpu_enabled:
492
- self.transformer_devices = [f"cuda:{gpu}" for gpu in GPU_CONFIG["transformer_workers"]]
493
- self.vae_devices = [f"cuda:{gpu}" for gpu in GPU_CONFIG["vae_workers"]]
494
- self.current_transformer_idx = 0
495
- self.current_vae_idx = 0
496
-
497
- print(f"[INFO] Configuração 4-Workers:")
498
- print(f" Transformer Workers: {self.transformer_devices}")
499
- print(f" VAE Workers: {self.vae_devices}")
500
- else:
501
- self.device_ltx = self.device_vae = "cuda" if torch.cuda.is_available() else "cpu"
502
- print("[INFO] Usando configuração single-GPU")
503
-
504
- self.config = self._load_config("ltxv-13b-0.9.8-distilled-fp8.yaml")
505
- self.pipeline, self.latent_upsampler = self._load_models_from_hub()
506
- self._setup_4gpu_workers()
507
-
508
- self.runtime_autocast_dtype = self._get_precision_dtype()
509
-
510
- # Configurar VAE managers para todas as GPUs VAE
511
- self.vae_managers = []
512
- if self.multi_gpu_enabled:
513
- for vae_device in self.vae_devices:
514
- # Usar o mesmo VAE manager singleton mas configurar para dispositivos diferentes
515
- manager = type(vae_manager_singleton)() # Nova instância
516
- manager.attach_pipeline(
517
- self.pipeline,
518
- device=vae_device,
519
- autocast_dtype=self.runtime_autocast_dtype
520
- )
521
- self.vae_managers.append(manager)
522
- else:
523
- vae_manager_singleton.attach_pipeline(
524
- self.pipeline,
525
- device=self.device_vae,
526
- autocast_dtype=self.runtime_autocast_dtype
527
- )
528
-
529
- self._tmp_dirs = set()
530
- RESULTS_DIR.mkdir(exist_ok=True)
531
- print(f"[INFO] VideoService 4-Workers pronto. Tempo: {time.perf_counter()-t0:.2f}s")
532
-
533
-
534
-
535
- def _set_generation_environment(self):
536
- """Prepara o ambiente para geração (LTX pipeline)."""
537
- if not ENABLE_MEMORY_OPTIMIZATION:
538
- return
539
-
540
- print("\n [VRAM Manager] Configurando ambiente de GERAÇÃO...")
541
-
542
- if self.multi_gpu_enabled:
543
- transformer_device = self.transformer_devices[0] # Usar primeira GPU transformer
544
- # Garantir que transformer e text_encoder estão na GPU correta
545
- if not next(self.pipeline.transformer.parameters()).is_cuda:
546
- self.pipeline.transformer.to(transformer_device)
547
- if not next(self.pipeline.text_encoder.parameters()).is_cuda:
548
- self.pipeline.text_encoder.to(transformer_device)
549
- # Mover VAE para CPU durante geração
550
- if next(self.pipeline.vae.parameters()).is_cuda:
551
- self.pipeline.vae.to('cpu')
552
- else:
553
- # Comportamento original para single GPU
554
- if next(self.pipeline.vae.parameters()).is_cuda:
555
- self.pipeline.vae.to('cpu')
556
- if not next(self.pipeline.transformer.parameters()).is_cuda:
557
- self.pipeline.transformer.to(self.device_ltx)
558
- if not next(self.pipeline.text_encoder.parameters()).is_cuda:
559
- self.pipeline.text_encoder.to(self.device_ltx)
560
-
561
- torch.cuda.empty_cache()
562
- print(" [VRAM Manager] Ambiente de GERAÇÃO pronto.\n")
563
-
564
- def _set_decode_environment(self):
565
- """Prepara o ambiente para decodificação (VAE pipeline)."""
566
- if not ENABLE_MEMORY_OPTIMIZATION:
567
- return
568
-
569
- print("\n [VRAM Manager] Configurando ambiente de DECODIFICAÇÃO...")
570
-
571
- if self.multi_gpu_enabled:
572
- # Mover transformer e text_encoder para CPU
573
- if next(self.pipeline.transformer.parameters()).is_cuda:
574
- self.pipeline.transformer.to('cpu')
575
- if next(self.pipeline.text_encoder.parameters()).is_cuda:
576
- self.pipeline.text_encoder.to('cpu')
577
-
578
- # Garantir que VAE está na primeira GPU VAE para decodificação
579
- vae_device = self.vae_devices[0]
580
- if not next(self.pipeline.vae.parameters()).is_cuda:
581
- self.pipeline.vae.to(vae_device)
582
- else:
583
- # Comportamento original para single GPU
584
- if next(self.pipeline.transformer.parameters()).is_cuda:
585
- self.pipeline.transformer.to('cpu')
586
- if next(self.pipeline.text_encoder.parameters()).is_cuda:
587
- self.pipeline.text_encoder.to('cpu')
588
- if not next(self.pipeline.vae.parameters()).is_cuda:
589
- self.pipeline.vae.to(self.device_vae)
590
-
591
- torch.cuda.empty_cache()
592
- print(" [VRAM Manager] Ambiente de DECODIFICAÇÃO pronto.\n")
593
-
594
- def _get_vae_manager(self, device):
595
- """Retorna o VAE manager para o dispositivo especificado."""
596
- if not self.multi_gpu_enabled:
597
- return vae_manager_singleton
598
-
599
- # Encontrar o manager correspondente ao dispositivo
600
- device_index = int(device.split(':')[-1])
601
- for i, vae_device in enumerate(self.vae_devices):
602
- if int(vae_device.split(':')[-1]) == device_index:
603
- return self.vae_managers[i]
604
- return self.vae_managers[0] # Fallback
605
-
606
- def refine_texture_only(self, latents_path: str, prompt: str, negative_prompt: str,
607
- guidance_scale: float, seed: Optional[int] = None,
608
- conditioning_items: Optional[List[ConditioningItem]] = None) -> Tuple[str, str, torch.Tensor]:
609
- """Versão simplificada para 4 workers."""
610
- print("\n[INFO] Iniciando ETAPA 2 com 4 Workers...")
611
-
612
- temp_dir = tempfile.mkdtemp(prefix="ltxv_refine_")
613
- self._register_tmp_dir(temp_dir)
614
- used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
615
-
616
- # FASE 1: Geração com worker Transformer
617
- print("[LOG] FASE 1: Geração de Latentes")
618
- self._set_generation_environment()
619
-
620
- # Carregar latentes
621
- latents_to_refine = torch.load(latents_path)
622
- transformer_device = self.transformer_devices[0] # Usar primeira GPU transformer
623
- latents_to_refine = latents_to_refine.to(transformer_device)
624
- print(f" [LOG] Latentes carregados no Worker {transformer_device}. Shape: {latents_to_refine.shape}")
625
-
626
- with torch.autocast(device_type=transformer_device.split(':')[0],
627
- dtype=self.runtime_autocast_dtype,
628
- enabled=('cuda' in transformer_device)):
629
-
630
- refine_height = latents_to_refine.shape[3] * self.pipeline.vae_scale_factor
631
- refine_width = latents_to_refine.shape[4] * self.pipeline.vae_scale_factor
632
-
633
- second_pass_kwargs = {
634
- "prompt": prompt, "negative_prompt": negative_prompt,
635
- "height": refine_height, "width": refine_width,
636
- "frame_rate": int(DEFAULT_FPS), "num_frames": latents_to_refine.shape[2],
637
- "latents": latents_to_refine, "guidance_scale": float(guidance_scale),
638
- "output_type": "latent",
639
- "generator": torch.Generator(device=transformer_device).manual_seed(used_seed),
640
- "conditioning_items": conditioning_items,
641
- **(self.config.get("second_pass", {}))
642
- }
643
-
644
- final_latents = self.pipeline(**second_pass_kwargs).images
645
- print(f" [LOG] Latentes refinados. Shape: {final_latents.shape}")
646
-
647
- # Mover latentes refinados para CPU
648
- final_latents_cpu = final_latents.cpu()
649
- del final_latents, latents_to_refine
650
- torch.cuda.empty_cache()
651
-
652
- # FASE 2: Decodificação
653
- print("\n[LOG] FASE 2: Decodificação")
654
- self._set_decode_environment()
655
-
656
- total_latents = final_latents_cpu.shape[2]
657
- pontos_de_corte, segment_sizes = self._calculate_dynamic_cuts(total_latents)
658
-
659
- if len(pontos_de_corte) == 1:
660
- vae_device = self.vae_devices[0] # Usar primeira GPU VAE
661
- latents_for_decode = final_latents_cpu.to(vae_device)
662
- vae_manager = self._get_vae_manager(vae_device)
663
-
664
- pixel_tensor = vae_manager.decode(
665
- latents_for_decode,
666
- decode_timestep=float(self.config.get("decode_timestep", 0.05))
667
- ).cpu()
668
- else:
669
- print(f" [LOG] Decodificação em {len(pontos_de_corte)} chunks...")
670
- pixel_chunks_list = []
671
-
672
- for i, (start, end) in enumerate(pontos_de_corte):
673
- start, end = max(0, start), min(total_latents, end)
674
- if start >= end:
675
- continue
676
-
677
- latent_chunk = final_latents_cpu[:, :, start:end, :, :]
678
- # Usar sempre a primeira GPU VAE (evita problemas com múltiplos VAEs)
679
- vae_device = self.vae_devices[0]
680
- latent_chunk = latent_chunk.to(vae_device)
681
- vae_manager = self._get_vae_manager(vae_device)
682
-
683
- print(f" -> Decodificando Grupo {i+1} (latentes {start} a {end-1})")
684
-
685
- pixel_chunk = vae_manager.decode(
686
- latent_chunk,
687
- decode_timestep=float(self.config.get("decode_timestep", 0.05))
688
- )
689
- pixel_chunks_list.append(pixel_chunk.cpu())
690
- torch.cuda.empty_cache()
691
-
692
- pixel_tensor = self._stitch_dynamic_chunks(pixel_chunks_list, segment_sizes)
693
-
694
- # Salvar resultados
695
- video_path_out = self._save_video_from_tensor(pixel_tensor, "refined_video_final", used_seed, temp_dir)
696
- latents_path_out = self._save_latents_to_disk(final_latents_cpu, "latents_refined_final", used_seed)
697
-
698
- # Restaurar ambiente
699
- self._set_generation_environment()
700
-
701
- del final_latents_cpu
702
- self._finalize()
703
-
704
- print("\n[SUCCESS] ETAPA 2 com 4 Workers Concluída")
705
- return video_path_out, latents_path_out, pixel_tensor
706
-
707
-
708
-
709
-
710
- def _get_next_transformer_device(self):
711
- """Retorna o próximo dispositivo transformer (round-robin)."""
712
- if not self.multi_gpu_enabled:
713
- return self.device_ltx
714
-
715
- device = self.transformer_devices[self.current_transformer_idx]
716
- self.current_transformer_idx = (self.current_transformer_idx + 1) % len(self.transformer_devices)
717
- return device
718
-
719
- def _get_next_vae_device(self):
720
- """Retorna o próximo dispositivo VAE (round-robin)."""
721
- if not self.multi_gpu_enabled:
722
- return self.device_vae
723
-
724
- device = self.vae_devices[self.current_vae_idx]
725
- self.current_vae_idx = (self.current_vae_idx + 1) % len(self.vae_devices)
726
- return device
727
-
728
-
729
- @torch.no_grad()
730
- def _upsample_and_filter_latents(self, latents: torch.Tensor) -> torch.Tensor:
731
- """Upsampling com suporte a múltiplos workers VAE."""
732
- if not self.latent_upsampler:
733
- raise ValueError("Latent Upsampler não está carregado.")
734
-
735
- # Selecionar worker VAE para upscaling
736
- upsample_device = self._get_next_vae_device()
737
- latents = latents.to(upsample_device)
738
-
739
- latents_unnormalized = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
740
- upsampled_latents_unnormalized = self.latent_upsampler(latents_unnormalized)
741
- upsampled_latents_normalized = normalize_latents(upsampled_latents_unnormalized, self.pipeline.vae, vae_per_channel_normalize=True)
742
-
743
- return adain_filter_latent(latents=upsampled_latents_normalized, reference_latents=latents)
744
-
745
-
746
- def get_gpu_usage(self):
747
- """Monitora o uso de VRAM em todas as 4 GPUs."""
748
- if not torch.cuda.is_available():
749
- return "CUDA não disponível"
750
-
751
- info = []
752
- for i in range(torch.cuda.device_count()):
753
- alloc = torch.cuda.memory_allocated(i) / 1024**3
754
- cached = torch.cuda.memory_reserved(i) / 1024**3
755
- info.append(f"GPU{i}: {alloc:.2f}GB / {cached:.2f}GB")
756
-
757
- return " | ".join(info)
758
-
759
-
760
-
761
-
762
-
763
- # --------------------------------------------------------------------------
764
- # --- Métodos Públicos (API do Serviço) ---
765
- # --------------------------------------------------------------------------
766
-
767
- def _load_image_to_tensor_with_resize_and_crop(
768
- self,
769
- image_input: Union[str, Image.Image],
770
- target_height: int = 512,
771
- target_width: int = 768,
772
- just_crop: bool = False,
773
- ) -> torch.Tensor:
774
- """Load and process an image into a tensor.
775
-
776
- Args:
777
- image_input: Either a file path (str) or a PIL Image object
778
- target_height: Desired height of output tensor
779
- target_width: Desired width of output tensor
780
- just_crop: If True, only crop the image to the target size without resizing
781
- """
782
- if isinstance(image_input, str):
783
- image = Image.open(image_input).convert("RGB")
784
- elif isinstance(image_input, Image.Image):
785
- image = image_input
786
- else:
787
- raise ValueError("image_input must be either a file path or a PIL Image object")
788
-
789
- input_width, input_height = image.size
790
- aspect_ratio_target = target_width / target_height
791
- aspect_ratio_frame = input_width / input_height
792
- if aspect_ratio_frame > aspect_ratio_target:
793
- new_width = int(input_height * aspect_ratio_target)
794
- new_height = input_height
795
- x_start = (input_width - new_width) // 2
796
- y_start = 0
797
- else:
798
- new_width = input_width
799
- new_height = int(input_width / aspect_ratio_target)
800
- x_start = 0
801
- y_start = (input_height - new_height) // 2
802
-
803
- image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
804
- if not just_crop:
805
- image = image.resize((target_width, target_height))
806
-
807
- image = np.array(image)
808
- image = cv2.GaussianBlur(image, (3, 3), 0)
809
- frame_tensor = torch.from_numpy(image).float()
810
- frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
811
- frame_tensor = frame_tensor.permute(2, 0, 1)
812
- frame_tensor = (frame_tensor / 127.5) - 1.0
813
- # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
814
- return frame_tensor.unsqueeze(0).unsqueeze(2)
815
-
816
- def _prepare_conditioning_tensor(self, filepath, height, width, padding_values):
817
- print(f"[DEBUG] Carregando condicionamento: {filepath}")
818
- tensor = self._load_image_to_tensor_with_resize_and_crop(filepath, height, width)
819
- tensor = torch.nn.functional.pad(tensor, padding_values)
820
- out = tensor.to(self.transformer_devices[0] , dtype=self.runtime_autocast_dtype) if self.transformer_devices[0] == "cuda" else tensor.to(self.transformer_devices[0] )
821
- print(f"[DEBUG] Cond shape={tuple(out.shape)} dtype={out.dtype} device={out.device}")
822
- return out
823
-
824
- def generate_low_resolution(
825
- self,
826
- prompt: str,
827
- negative_prompt: str,
828
- height: int,
829
- width: int,
830
- duration_secs: float,
831
- guidance_scale: float,
832
- seed: Optional[int] = None,
833
- image_filepaths: Optional[List[str]] = None
834
- ) -> Tuple[str, str, int]:
835
- """
836
- ETAPA 1: Gera um vídeo e latentes em resolução base a partir de um prompt e
837
- condicionamentos opcionais.
838
- """
839
- print("[INFO] Iniciando ETAPA 1: Geração de Baixa Resolução...")
840
-
841
- self._set_generation_environment()
842
-
843
- # --- Configuração de Seed e Diretórios ---
844
- used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
845
- #seed_everything(used_seed)
846
- print(f" - Usando Seed: {used_seed}")
847
-
848
- temp_dir = tempfile.mkdtemp(prefix="ltxv_low_")
849
- self._register_tmp_dir(temp_dir)
850
- results_dir = "/app/output"
851
- os.makedirs(results_dir, exist_ok=True)
852
-
853
- # --- Cálculo de Dimensões e Frames ---
854
- actual_num_frames = int(round(duration_secs * DEFAULT_FPS))
855
- downscaled_height, downscaled_width = self._calculate_downscaled_dims(height, width)
856
-
857
-
858
- height_padded = ((downscaled_height - 1) // 32 + 1) * 32
859
- width_padded = ((downscaled_width - 1) // 32 + 1) * 32
860
- padding_values = calculate_padding(downscaled_height, downscaled_width, height_padded, width_padded)
861
-
862
- conditioning_items = []
863
- for filepath in image_filepaths:
864
- cond_tensor = self._prepare_conditioning_tensor(filepath, downscaled_height, downscaled_width, padding_values)
865
- conditioning_items.append(ConditioningItem(cond_tensor, 0, 1.0))
866
-
867
-
868
- print(f" - Frames: {actual_num_frames}, Duração: {duration_secs}s")
869
- print(f" - Dimensões de Saída: {downscaled_height}x{downscaled_width}")
870
-
871
- # --- Execução da Pipeline ---
872
- with torch.autocast(device_type=self.transformer_devices[0] .split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.transformer_devices[0] == 'cuda')):
873
-
874
- first_pass_kwargs = {
875
- "prompt": prompt,
876
- "negative_prompt": negative_prompt,
877
- "height": downscaled_height,
878
- "width": downscaled_width,
879
- "num_frames": (actual_num_frames//8)+1,
880
- "frame_rate": int(DEFAULT_FPS),
881
- "generator": torch.Generator(device=self.transformer_devices[0] ).manual_seed(used_seed),
882
- "output_type": "latent",
883
- "conditioning_items": conditioning_items,
884
- "guidance_scale": float(guidance_scale),
885
- **(self.config.get("first_pass", {}))
886
- }
887
-
888
- print(" - Enviando para a pipeline LTX...")
889
- latents = self.pipeline(**first_pass_kwargs).images
890
- print(f" - Latentes gerados com shape: {latents.shape}")
891
-
892
-
893
- #_upsample_and_filter_latents
894
- latents = self._upsample_and_filter_latents(latents)
895
- print(f" - Latentes com upscaler: {latents.shape}")
896
-
897
- tensor_path = self._save_latents_to_disk(latents, "latents_low_res", used_seed)
898
-
899
- self._finalize()
900
-
901
- final_video_path, final_latents_path, _ = self.refine_texture_only(
902
- latents_path=tensor_path,
903
- prompt=prompt,
904
- negative_prompt=negative_prompt,
905
- guidance_scale=guidance_scale,
906
- seed=used_seed,
907
- conditioning_items=conditioning_items,
908
- )
909
-
910
- # --- Limpeza ---
911
- self._finalize()
912
- self._set_generation_environment()
913
-
914
- print("[SUCCESS] ETAPA 1 Concluída.")
915
- return final_video_path, final_latents_path, used_seed
916
-
917
-
918
- def apply_secondary_refinement(
919
- self,
920
- latents_path: str,
921
- prompt: str,
922
- negative_prompt: str,
923
- guidance_scale: float,
924
- seed: int,
925
- # Parâmetros para controlar a divisão principal (Nível 1)
926
- macro_chunk_size: int = 8,
927
- macro_overlap: int = 2
928
- ) -> str: # A função agora retorna apenas o caminho do vídeo final.
929
- """
930
- Função "ponte" aprimorada que orquestra um refinamento secundário
931
- usando uma lógica de MACRO-DIVISÃO aninhada para processar vídeos
932
- muito longos de forma robusta.
933
- """
934
-
935
- print("[LOG] Preparando ambiente da GPU para o refinamento...")
936
- self._set_generation_environment()
937
-
938
- print(f"[LOG] Carregando latentes principais de: {latents_path}")
939
- initial_latents = torch.load(latents_path).cpu()
940
- total_latents = initial_latents.shape[2]
941
-
942
- print(f"[LOG] Nível 1 (Macro): Calculando divisão para {total_latents} latentes...")
943
- macro_cuts, macro_segment_sizes = self._calculate_dynamic_cuts(
944
- total_latents,
945
- min_chunk_size=macro_chunk_size,
946
- overlap=macro_overlap
947
- )
948
- print(f"[LOG] Nível 1 (Macro): Trabalho dividido em {len(macro_cuts)} tarefas principais.")
949
-
950
- # 3. EXECUTAR CADA TAREFA EM UM LOOP
951
- pixel_results = []
952
- for i, (start, end) in enumerate(macro_cuts):
953
- task_id_str = f"[Tarefa {i+1}/{len(macro_cuts)}]"
954
- print(f"\n--- Processando {task_id_str} (latentes {start} a {end-1}) ---")
955
-
956
- latent_chunk = initial_latents[:, :, start:end, :, :]
957
- tensor_path = self._save_latents_to_disk(latent_chunk, "latents_chuck_i", seed)
958
-
959
- _video_path, _latents_path, pixel_tensor_chunk = self.refine_texture_only(
960
- latents_path=tensor_path,
961
- prompt=prompt,
962
- negative_prompt=negative_prompt,
963
- guidance_scale=guidance_scale,
964
- seed=seed + i, # Garante seeds diferentes para cada tarefa
965
- conditioning_items=None,
966
- )
967
-
968
- # Armazena o tensor de pixels resultante em memória
969
- pixel_results.append(pixel_tensor_chunk)
970
- torch.cuda.empty_cache() # Limpa VRAM entre as tarefas
971
-
972
-
973
- final_pixel_tensor = self._stitch_dynamic_chunks(
974
- pixel_chunks_list=pixel_results,
975
- segment_sizes=macro_segment_sizes,
976
- overlap=macro_overlap
977
- )
978
-
979
- print(f"[LOG] Costura final (Nível 1) concluída. Shape do tensor final: {final_pixel_tensor.shape}")
980
-
981
- # 5. SALVAR O VÍDEO FINAL E LIMPAR
982
- final_video_path = self._save_video_from_tensor(
983
- pixel_tensor=final_pixel_tensor,
984
- base_filename="final_video_stitched",
985
- seed=seed,
986
- # Salva o vídeo em um diretório temporário antes de movê-lo para a saída final
987
- temp_dir=tempfile.mkdtemp(prefix="ltxv_final_")
988
- )
989
-
990
- del pixel_results, final_pixel_tensor
991
- self._finalize() # Limpa todos os diretórios temporários registrados e a memória
992
-
993
- print(f"\n[SUCCESS] Processo de Macro-Divisão concluído. Vídeo final em: {final_video_path}")
994
- self._set_generation_environment()
995
-
996
- # Retorna apenas o caminho do vídeo final consolidado
997
- return final_video_path, latents_path
998
-
999
- def _calculate_dynamic_cuts(
1000
- self,
1001
- total_latents: int,
1002
- min_chunk_size: int = 5,
1003
- overlap: int = 2
1004
- ) -> tuple[list[tuple[int, int]], list[int]]:
1005
- """
1006
- Calcula dinamicamente os pontos de corte para 'X' chunks.
1007
- """
1008
- if total_latents <= min_chunk_size + overlap:
1009
- print(f" [LOG] Detecção: Vídeo muito curto ({total_latents} latentes). Usando 1 chunk.")
1010
- return [(0, total_latents)], [total_latents]
1011
-
1012
- # Regra: O cálculo principal é feito em (total - 2) latentes
1013
- effective_total_latents = total_latents - 2
1014
-
1015
- # Determina o número de chunks (X) para maximizar o uso da VRAM
1016
- num_chunks = effective_total_latents // min_chunk_size
1017
- if num_chunks == 0: # Garante pelo menos um chunk
1018
- num_chunks = 1
1019
-
1020
- # Distribui os latentes entre os chunks
1021
- base_size = effective_total_latents // num_chunks
1022
- remainder = effective_total_latents % num_chunks
1023
-
1024
- segment_sizes = []
1025
- for i in range(num_chunks):
1026
- size = base_size + (1 if i < remainder else 0)
1027
- segment_sizes.append(size)
1028
-
1029
- # Regra: Adiciona os 2 latentes restantes ao último chunk
1030
- segment_sizes[-1] += 2
1031
-
1032
- print(f" [LOG] Divisão dinâmica: {total_latents} latentes em {num_chunks} chunks.")
1033
- print(f" Tamanhos de conteúdo: {segment_sizes}")
1034
-
1035
- # Calcula os pontos de corte (start, end) com sobreposição
1036
- cut_points = []
1037
- cursor = 0
1038
- for i in range(num_chunks):
1039
- start_pos = cursor if i == 0 else cursor - overlap
1040
-
1041
- # O último chunk sempre vai até o final
1042
- end_pos = total_latents if i == num_chunks - 1 else cursor + segment_sizes[i] + overlap
1043
-
1044
- cut_points.append((start_pos, end_pos))
1045
- cursor += segment_sizes[i]
1046
-
1047
- return cut_points, segment_sizes
1048
-
1049
- def _stitch_dynamic_chunks(
1050
- self,
1051
- pixel_chunks_list: list[torch.Tensor],
1052
- segment_sizes: list[int],
1053
- overlap: int = 2
1054
- ) -> torch.Tensor:
1055
- """
1056
- Costura uma lista de chunks de pixels decodificados.
1057
- """
1058
- if not pixel_chunks_list:
1059
- return torch.empty(0)
1060
- if len(pixel_chunks_list) == 1:
1061
- return pixel_chunks_list[0]
1062
-
1063
- final_parts = []
1064
-
1065
- # 1. Processa o primeiro chunk
1066
- # Mantém apenas os frames correspondentes ao seu tamanho de conteúdo
1067
- first_chunk_frame_count = segment_sizes[0] * 8
1068
- final_parts.append(pixel_chunks_list[0][:, :, :first_chunk_frame_count, :, :])
1069
-
1070
- # 2. Processa os chunks restantes
1071
- for i in range(1, len(pixel_chunks_list)):
1072
- chunk = pixel_chunks_list[i]
1073
- # Descarta os frames da sobreposição inicial e pega todo o resto
1074
- discard_frames = overlap * 8
1075
- final_parts.append(chunk[:, :, discard_frames:, :, :])
1076
-
1077
- return torch.cat(final_parts, dim=2)
1078
-
1079
- def refine_texture_onl4y(
1080
- self,
1081
- latents_path: str,
1082
- prompt: str,
1083
- negative_prompt: str,
1084
- guidance_scale: float,
1085
- seed: Optional[int] = None,
1086
- conditioning_items: Optional[List[ConditioningItem]] = None
1087
- ) -> Tuple[str, str]:
1088
- """
1089
- Refina e decodifica latentes com gerenciamento explícito de modelos
1090
- na GPU e lógica de chunking dinâmico para máxima performance e robustez.
1091
- """
1092
- print("\n======================================================================")
1093
- print("====== [INFO] Iniciando ETAPA 2: Refinamento e Decodificação Dinâmica ======")
1094
- print("======================================================================\n")
1095
-
1096
- temp_dir = tempfile.mkdtemp(prefix="ltxv_refine_")
1097
- self._register_tmp_dir(temp_dir)
1098
- used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
1099
-
1100
- # --- FASE 1: GERAÇÃO DE LATENTES (TRABALHO DO TRANSFORMER) ---
1101
- print("[LOG] FASE 1: Geração de Latentes (Transformer na GPU)")
1102
- self._set_generation_environment()
1103
-
1104
- latents_to_refine = torch.load(latents_path).to(self.transformer_devices[0] )
1105
- print(f" [LOG] Latentes carregados para a GPU. Shape: {latents_to_refine.shape}")
1106
-
1107
- with torch.autocast(device_type=self.transformer_devices[0] .split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.device == 'cuda')):
1108
- refine_height = latents_to_refine.shape[3] * self.pipeline.vae_scale_factor
1109
- refine_width = latents_to_refine.shape[4] * self.pipeline.vae_scale_factor
1110
- second_pass_kwargs = {
1111
- "prompt": prompt, "negative_prompt": negative_prompt, "height": refine_height, "width": refine_width,
1112
- "frame_rate": int(DEFAULT_FPS), "num_frames": latents_to_refine.shape[2],
1113
- "latents": latents_to_refine, "guidance_scale": float(guidance_scale), "output_type": "latent",
1114
- "generator": torch.Generator(device=self.transformer_devices[0] ).manual_seed(used_seed),
1115
- "conditioning_items": conditioning_items, **(self.config.get("second_pass", {}))
1116
- }
1117
- print(" [LOG] Enviando para a pipeline de refinamento (Transformer)...")
1118
- final_latents = self.pipeline(**second_pass_kwargs).images
1119
- print(f" [LOG] [SUCESSO] Latentes refinados. Shape: {final_latents.shape}")
1120
-
1121
- print(" [LOG] Geração de latentes concluída. Movendo resultado para a CPU.")
1122
- final_latents_cpu = final_latents.cpu()
1123
- del final_latents, latents_to_refine
1124
- torch.cuda.empty_cache()
1125
-
1126
- # --- FASE 2: DECODIFICAÇÃO EM CHUNKS (TRABALHO DO VAE) ---
1127
- print("\n[LOG] FASE 2: Decodificação de Latentes (VAE na GPU)")
1128
- self._set_decode_environment()
1129
-
1130
- total_latents = final_latents_cpu.shape[2]
1131
-
1132
- # AQUI ESTÁ A MUDANÇA: Substituímos a lógica fixa pela chamada da função dinâmica.
1133
- pontos_de_corte, segment_sizes = self._calculate_dynamic_cuts(total_latents)
1134
-
1135
- if len(pontos_de_corte) == 1:
1136
- pixel_tensor = vae_manager_singleton.decode(
1137
- final_latents_cpu.to(self.transformer_devices[0] ),
1138
- decode_timestep=float(self.config.get("decode_timestep", 0.05))
1139
- ).cpu()
1140
- else:
1141
- print(f" [LOG] Ativando modo de janela deslizante para {len(pontos_de_corte)} chunks.")
1142
- pixel_chunks_list = []
1143
- for i, (start, end) in enumerate(pontos_de_corte):
1144
- # Garante que os slices sejam válidos dentro dos limites do tensor.
1145
- start, end = max(0, start), min(total_latents, end)
1146
- if start >= end: continue
1147
-
1148
- latent_chunk = final_latents_cpu[:, :, start:end, :, :]
1149
- print(f" -> Decodificando Grupo {i+1}/{len(pontos_de_corte)} (latentes {start} a {end-1}), shape: {latent_chunk.shape}")
1150
-
1151
- pixel_chunk = vae_manager_singleton.decode(
1152
- latent_chunk.to(self.transformer_devices[0] ),
1153
- decode_timestep=float(self.config.get("decode_timestep", 0.05))
1154
- )
1155
- pixel_chunks_list.append(pixel_chunk.cpu())
1156
- torch.cuda.empty_cache()
1157
-
1158
- print(" [LOG] Costurando os vídeos decodificados...")
1159
- pixel_tensor = self._stitch_dynamic_chunks(pixel_chunks_list, segment_sizes)
1160
-
1161
- print(f"\n[LOG] [SUCESSO] Tensor de pixels final montado na CPU com shape: {pixel_tensor.shape}")
1162
-
1163
- # --- FASE 3: SALVAMENTO E RESTAURAÇÃO DO AMBIENTE ---
1164
- print("\n[LOG] FASE 3: Salvamento e Restauração do Ambiente da GPU")
1165
-
1166
- video_path_out = self._save_video_from_tensor(pixel_tensor, "refined_video_final", used_seed, temp_dir)
1167
- latents_path_out = self._save_latents_to_disk(final_latents_cpu, "latents_refined_final", used_seed)
1168
-
1169
- print(" [LOG] Tarefa concluída. Restaurando ambiente de GERAÇÃO na GPU para a próxima execução...")
1170
- self._set_decode_environment()
1171
-
1172
- print(" [LOG] Liberando tensores finais da memória da CPU.")
1173
- del final_latents_cpu
1174
- self._finalize()
1175
-
1176
- return video_path_out, latents_path_out, pixel_tensor
1177
-
1178
-
1179
- def encode_latents_to_mp4(self, latents_path: str, fps: int = int(DEFAULT_FPS)) -> str:
1180
- """Decodifica um tensor de latentes salvo e o salva como um vídeo MP4."""
1181
- latents = torch.load(latents_path)
1182
- temp_dir = tempfile.mkdtemp(prefix="ltxv_enc_")
1183
- self._register_tmp_dir(temp_dir)
1184
- seed = random.randint(0, 99999) # Seed apenas para nome do arquivo
1185
-
1186
- try:
1187
- chunks = self._split_latents_with_overlap(latents)
1188
- pixel_chunks = []
1189
-
1190
- with torch.autocast(device_type=self.transformer_devices[0] .split(':')[0], dtype=self.runtime_autocast_dtype, enabled=(self.transformer_devices[0] == 'cuda')):
1191
- for chunk in chunks:
1192
- if chunk.shape[2] == 0: continue
1193
- pixel_chunk = vae_manager_singleton.decode(chunk.to(self.transformer_devices[0] ), decode_timestep=float(self.config.get("decode_timestep", 0.05)))
1194
- pixel_chunks.append(pixel_chunk)
1195
-
1196
- final_pixel_tensor = self._merge_chunks_with_overlap(pixel_chunks)
1197
- final_video_path = self._save_video_from_tensor(final_pixel_tensor, f"final_video_{seed}", seed, temp_dir, fps=fps)
1198
- return final_video_path
1199
-
1200
- except Exception as e:
1201
- print(f"[ERROR] Falha ao encodar latentes para MP4: {e}")
1202
- traceback.print_exc()
1203
- raise
1204
- finally:
1205
- self._finalize()
1206
-
1207
- def _finalize(self):
1208
- """Limpa a memória da GPU e os diretórios temporários."""
1209
- if LTXV_DEBUG:
1210
- print("[DEBUG] Finalize: iniciando limpeza...")
1211
-
1212
- gc.collect()
1213
- if torch.cuda.is_available():
1214
- torch.cuda.empty_cache()
1215
- torch.cuda.ipc_collect()
1216
-
1217
- # Limpa todos os diretórios temporários registrados
1218
- for d in list(self._tmp_dirs):
1219
- shutil.rmtree(d, ignore_errors=True)
1220
- self._tmp_dirs.remove(d)
1221
- if LTXV_DEBUG:
1222
- print(f"[DEBUG] Diretório temporário removido: {d}")
1223
-
1224
- def _load_config(self, config_filename: str) -> Dict:
1225
- """Carrega o arquivo de configuração YAML."""
1226
- config_path = LTX_VIDEO_REPO_DIR / "configs" / config_filename
1227
- print(f"[INFO] Carregando configuração de: {config_path}")
1228
- with open(config_path, "r") as file:
1229
- return yaml.safe_load(file)
1230
-
1231
- def _load_models_from_hub(self):
1232
- """Baixa e cria as instâncias da pipeline e do upsampler."""
1233
- t0 = time.perf_counter()
1234
- LTX_REPO = "Lightricks/LTX-Video"
1235
- print("[INFO] Baixando checkpoint principal...")
1236
- self.config["checkpoint_path"] = hf_hub_download(
1237
- repo_id=LTX_REPO, filename=self.config["checkpoint_path"],
1238
- token=os.getenv("HF_TOKEN"),
1239
- )
1240
- print(f"[INFO] Checkpoint principal em: {self.config['checkpoint_path']}")
1241
-
1242
- print("[INFO] Construindo pipeline...")
1243
- pipeline = create_ltx_video_pipeline(
1244
- ckpt_path=self.config["checkpoint_path"],
1245
- precision=self.config["precision"],
1246
- text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
1247
- sampler=self.config["sampler"],
1248
- device="cpu", # Carrega em CPU primeiro
1249
- enhance_prompt=False
1250
- )
1251
- print("[INFO] Pipeline construída.")
1252
-
1253
- latent_upsampler = None
1254
- if self.config.get("spatial_upscaler_model_path"):
1255
- print("[INFO] Baixando upscaler espacial...")
1256
- self.config["spatial_upscaler_model_path"] = hf_hub_download(
1257
- repo_id=LTX_REPO, filename=self.config["spatial_upscaler_model_path"],
1258
- token=os.getenv("HF_TOKEN")
1259
- )
1260
- print(f"[INFO] Upscaler em: {self.config['spatial_upscaler_model_path']}")
1261
-
1262
- print("[INFO] Construindo latent_upsampler...")
1263
- latent_upsampler = create_latent_upsampler(self.config["spatial_upscaler_model_path"], device="cpu")
1264
- print("[INFO] Latent upsampler construído.")
1265
-
1266
- print(f"[INFO] Carregamento de modelos concluído em {time.perf_counter()-t0:.2f}s")
1267
- return pipeline, latent_upsampler
1268
-
1269
- def _move_models_to_device(self):
1270
- """Move os modelos carregados para o dispositivo de computação (GPU/CPU)."""
1271
- print(f"[INFO] Movendo modelos para o dispositivo: {self.transformer_devices[0] }")
1272
- self.pipeline.to(self.transformer_devices[0] )
1273
- if self.latent_upsampler:
1274
- self.latent_upsampler.to(self.transformer_devices[0] )
1275
-
1276
- def _get_precision_dtype(self) -> torch.dtype:
1277
- """Determina o dtype para autocast com base na configuração de precisão."""
1278
- prec = str(self.config.get("precision", "")).lower()
1279
- if prec in ["float8_e4m3fn", "bfloat16"]:
1280
- return torch.bfloat16
1281
- elif prec == "mixed_precision":
1282
- return torch.float16
1283
- return torch.float32
1284
-
1285
- @torch.no_grad()
1286
- def _upsample_and_filter_latents4(self, latents: torch.Tensor) -> torch.Tensor:
1287
- """Aplica o upsample espacial e o filtro AdaIN aos latentes."""
1288
- if not self.latent_upsampler:
1289
- raise ValueError("Latent Upsampler não está carregado para a operação de upscale.")
1290
-
1291
- latents_unnormalized = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
1292
- upsampled_latents_unnormalized = self.latent_upsampler(latents_unnormalized)
1293
- upsampled_latents_normalized = normalize_latents(upsampled_latents_unnormalized, self.pipeline.vae, vae_per_channel_normalize=True)
1294
-
1295
- # Filtro AdaIN para manter consistência de cor/estilo com o vídeo de baixa resolução
1296
- return adain_filter_latent(latents=upsampled_latents_normalized, reference_latents=latents)
1297
-
1298
- def _prepare_conditioning_tensor_from_path(self, filepath: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
1299
- """Carrega uma imagem, redimensiona, aplica padding e move para o dispositivo."""
1300
- tensor = self._load_image_to_tensor_with_resize_and_crop(filepath, height, width)
1301
- tensor = F.pad(tensor, padding)
1302
- return tensor.to(self.transformer_devices[0] , dtype=self.runtime_autocast_dtype)
1303
-
1304
- def _calculate_downscaled_dims(self, height: int, width: int) -> Tuple[int, int]:
1305
- """Calcula as dimensões para o primeiro passo (baixa resolução)."""
1306
- height_padded = ((height - 1) // 8 + 1) * 8
1307
- width_padded = ((width - 1) // 8 + 1) * 8
1308
-
1309
- downscale_factor = self.config.get("downscale_factor", 0.6666666)
1310
- vae_scale_factor = self.pipeline.vae_scale_factor
1311
-
1312
- target_w = int(width_padded * downscale_factor)
1313
- downscaled_width = target_w - (target_w % vae_scale_factor)
1314
-
1315
- target_h = int(height_padded * downscale_factor)
1316
- downscaled_height = target_h - (target_h % vae_scale_factor)
1317
-
1318
- return downscaled_height, downscaled_width
1319
-
1320
- def _split_latents_with_overlap(self, latents: torch.Tensor, overlap: int = 1) -> List[torch.Tensor]:
1321
- """Divide um tensor de latentes em dois chunks com sobreposição."""
1322
- total_frames = latents.shape[2]
1323
- if total_frames <= overlap:
1324
- return [latents]
1325
-
1326
- mid_point = max(overlap, total_frames // 2)
1327
- chunk1 = latents[:, :, :mid_point, :, :]
1328
- # O segundo chunk começa 'overlap' frames antes para criar a sobreposição
1329
- chunk2 = latents[:, :, mid_point - overlap:, :, :]
1330
-
1331
- return [c for c in [chunk1, chunk2] if c.shape[2] > 0]
1332
-
1333
- def _merge_chunks_with_overlap(self, chunks: List[torch.Tensor], overlap: int = 1) -> torch.Tensor:
1334
- """Junta uma lista de chunks, removendo a sobreposição."""
1335
- if not chunks:
1336
- return torch.empty(0)
1337
- if len(chunks) == 1:
1338
- return chunks[0]
1339
-
1340
- # Pega o primeiro chunk sem o frame de sobreposição final
1341
- merged_list = [chunks[0][:, :, :-overlap, :, :]]
1342
- # Adiciona os chunks restantes
1343
- merged_list.extend(chunks[1:])
1344
-
1345
- return torch.cat(merged_list, dim=2)
1346
-
1347
- def _save_latents_to_disk(self, latents_tensor: torch.Tensor, base_filename: str, seed: int) -> str:
1348
- """Salva um tensor de latentes em um arquivo .pt."""
1349
- latents_cpu = latents_tensor.detach().to("cpu")
1350
- tensor_path = RESULTS_DIR / f"{base_filename}_{seed}.pt"
1351
- torch.save(latents_cpu, tensor_path)
1352
- if LTXV_DEBUG:
1353
- print(f"[DEBUG] Latentes salvos em: {tensor_path}")
1354
- return str(tensor_path)
1355
-
1356
- def _save_video_from_tensor(self, pixel_tensor: torch.Tensor, base_filename: str, seed: int, temp_dir: str, fps: int = int(DEFAULT_FPS)) -> str:
1357
- """Salva um tensor de pixels como um arquivo de vídeo MP4."""
1358
- temp_path = os.path.join(temp_dir, f"{base_filename}_{seed}.mp4")
1359
- video_encode_tool_singleton.save_video_from_tensor(pixel_tensor, temp_path, fps=fps)
1360
-
1361
- final_path = RESULTS_DIR / f"{base_filename}_{seed}.mp4"
1362
- shutil.move(temp_path, final_path)
1363
- print(f"[INFO] Vídeo final salvo em: {final_path}")
1364
- return str(final_path)
1365
-
1366
- def _seed_everething(self, seed: int):
1367
- random.seed(seed)
1368
- np.random.seed(seed)
1369
- torch.manual_seed(seed)
1370
- if torch.cuda.is_available():
1371
- torch.cuda.manual_seed(seed)
1372
- if torch.backends.mps.is_available():
1373
- torch.mps.manual_seed(seed)
1374
-
1375
- def _register_tmp_dir(self, dir_path: str):
1376
- """Registra um diretório temporário para limpeza posterior."""
1377
- if dir_path and os.path.isdir(dir_path):
1378
- self._tmp_dirs.add(dir_path)
1379
- if LTXV_DEBUG:
1380
- print(f"[DEBUG] Diretório temporário registrado: {dir_path}")
1381
-
1382
- # ==============================================================================
1383
- # 4. INSTANCIAÇÃO E PONTO DE ENTRADA (Exemplo)
1384
- # ==============================================================================
1385
-
1386
- print("Criando instância do VideoService. O carregamento do modelo começará agora...")
1387
- video_generation_service = VideoService()
1388
- print("Instância do VideoService pronta para uso.")