Upload 2 files

- api/ltx_server_refactored_complete (1).py  +288 -0
- api/vae_server.py  +162 -0
api/ltx_server_refactored_complete (1).py
ADDED
@@ -0,0 +1,288 @@
# FILE: api/ltx_server_refactored_complete.py
# DESCRIPTION: Final orchestrator for LTX-Video generation.
# This version internalizes conditioning-item preparation, accepting a raw
# list of media items directly in its main generation function for maximum
# simplicity and encapsulation.

import gc
import json
import logging
import os
import random  # used by _get_random_seed
import shutil
import sys
import tempfile
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import yaml
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

# ==============================================================================
# --- PROJECT SETUP AND IMPORTS ---
# ==============================================================================

# Logging configuration and warning suppression
import warnings
warnings.filterwarnings("ignore")
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
log_level = os.environ.get("ADUC_LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=log_level, format='[%(levelname)s] [%(name)s] %(message)s')

# --- Configuration Constants ---
DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
RESULTS_DIR = Path("/app/output")
DEFAULT_FPS = 24.0
FRAMES_ALIGNMENT = 8
LTX_REPO_ID = "Lightricks/LTX-Video"


# --- Modules from our architecture ---
try:
    from api.gpu_manager import gpu_manager
    from api.vae_server import vae_server_singleton
    from tools.video_encode_tool import video_encode_tool_singleton
    from api.ltx.ltx_utils import build_ltx_pipeline_on_cpu, seed_everything
    from api.ltx_pool_manager import LatentConditioningItem
    from api.utils.debug_utils import log_function_io
    # Referenced in type hints below; importable once api.vae_server has added
    # the LTX-Video repo to sys.path.
    from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem
except ImportError as e:
    logging.critical(f"A crucial import from the local API/architecture failed. Error: {e}", exc_info=True)
    sys.exit(1)

# ==============================================================================
# --- SERVICE CLASS (THE ORCHESTRATOR) ---
# ==============================================================================

class VideoService:
    """
    Orchestrates the high-level logic of video generation, with internalized
    conditioning-item preparation.
    """

    @log_function_io
    def __init__(self):
        t0 = time.time()
        logging.info("Initializing VideoService Orchestrator...")
        RESULTS_DIR.mkdir(parents=True, exist_ok=True)

        target_main_device_str = str(gpu_manager.get_ltx_device())
        target_vae_device_str = str(gpu_manager.get_ltx_vae_device())
        logging.info(f"LTX allocated to devices: Main='{target_main_device_str}', VAE='{target_vae_device_str}'")

        self.config = self._load_config()
        self._resolve_model_paths_from_cache()

        self.pipeline, self.latent_upsampler = build_ltx_pipeline_on_cpu(self.config)

        self.main_device = torch.device("cpu")
        self.vae_device = torch.device("cpu")
        self.move_to_device(main_device_str=target_main_device_str, vae_device_str=target_vae_device_str)

        self._apply_precision_policy()
        logging.info(f"VideoService ready. Startup time: {time.time() - t0:.2f}s")

    def _load_config(self) -> Dict:
        """Loads the YAML configuration file."""
        config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    def _resolve_model_paths_from_cache(self):
        """Finds the absolute paths to model files in the cache and updates the in-memory config."""
        logging.info("Resolving model paths from Hugging Face cache...")
        cache_dir = os.environ.get("HF_HOME")
        try:
            main_ckpt_path = hf_hub_download(repo_id=LTX_REPO_ID, filename=self.config["checkpoint_path"], cache_dir=cache_dir)
            self.config["checkpoint_path"] = main_ckpt_path
            if self.config.get("spatial_upscaler_model_path"):
                upscaler_path = hf_hub_download(repo_id=LTX_REPO_ID, filename=self.config["spatial_upscaler_model_path"], cache_dir=cache_dir)
                self.config["spatial_upscaler_model_path"] = upscaler_path
        except Exception as e:
            logging.critical(f"Failed to resolve model paths. Ensure setup.py ran correctly. Error: {e}", exc_info=True)
            sys.exit(1)

    @log_function_io
    def move_to_device(self, main_device_str: str, vae_device_str: str):
        """Moves pipeline components to their designated target devices."""
        target_main_device = torch.device(main_device_str)
        target_vae_device = torch.device(vae_device_str)
        self.main_device = target_main_device
        self.vae_device = target_vae_device
        self.pipeline.to(self.main_device)
        self.pipeline.vae.to(self.vae_device)
        if self.latent_upsampler:
            self.latent_upsampler.to(self.main_device)
        logging.info("LTX models successfully moved to target devices.")

    def move_to_cpu(self):
        """Moves all LTX components to CPU to free VRAM for other services."""
        self.move_to_device(main_device_str="cpu", vae_device_str="cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def finalize(self):
        """Cleans up GPU memory after a generation task."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        try:
            torch.cuda.ipc_collect()
        except Exception:
            pass

    # ==========================================================================
    # --- BUSINESS LOGIC: UNIFIED PUBLIC ORCHESTRATOR ---
    # ==========================================================================

    @log_function_io
    def generate_low_resolution(
        self,
        prompt_list: List[str],
        initial_media_items: Optional[List[Tuple[Union[str, Image.Image, torch.Tensor], int, float]]] = None,
        **kwargs
    ) -> Tuple[Optional[str], Optional[str], Optional[int]]:
        """
        [UNIFIED ORCHESTRATOR] Generates a low-resolution video from a prompt
        and a raw list of media items.
        """
        logging.info("Starting unified low-resolution generation...")
        used_seed = self._get_random_seed()
        seed_everything(used_seed)
        logging.info(f"Using randomly generated seed: {used_seed}")

        if not prompt_list:
            raise ValueError("Prompt is empty or contains no valid lines.")

        is_narrative = len(prompt_list) > 1
        num_chunks = len(prompt_list)
        total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
        frames_per_chunk = max(FRAMES_ALIGNMENT, (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT)
        overlap_frames = 9 if is_narrative else 0

        initial_conditions = []
        if initial_media_items:
            logging.info("Preparing initial conditioning items from raw media list...")
            initial_conditions = vae_server_singleton.generate_conditioning_items(
                media_items=[item[0] for item in initial_media_items],
                target_frames=[item[1] for item in initial_media_items],
                strengths=[item[2] for item in initial_media_items],
                target_resolution=(kwargs['height'], kwargs['width'])
            )

        temp_latent_paths = []
        overlap_condition_item: Optional[LatentConditioningItem] = None

        try:
            for i, chunk_prompt in enumerate(prompt_list):
                logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")

                if i < num_chunks - 1:
                    current_frames_base = frames_per_chunk
                else:
                    processed_frames_base = (num_chunks - 1) * frames_per_chunk
                    current_frames_base = total_frames - processed_frames_base

                current_frames = current_frames_base + (overlap_frames if i > 0 else 0)
                current_frames = self._align(current_frames, alignment_rule='n*8+1')

                # Copy the list so appending the overlap item does not mutate initial_conditions.
                current_conditions = list(initial_conditions) if i == 0 else []
                if overlap_condition_item:
                    current_conditions.append(overlap_condition_item)

                chunk_latents = self._generate_single_chunk_low(
                    prompt=chunk_prompt, num_frames=current_frames, seed=used_seed + i,
                    conditioning_items=current_conditions, **kwargs
                )
                if chunk_latents is None:
                    raise RuntimeError(f"Failed to generate latents for scene {i+1}.")

                if is_narrative and i < num_chunks - 1:
                    overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
                    overlap_condition_item = LatentConditioningItem(
                        latent_tensor=overlap_latents.cpu(),
                        media_frame_number=0,
                        conditioning_strength=1.0
                    )

                if i > 0:
                    chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]

                chunk_path = RESULTS_DIR / f"temp_chunk_{i}_{used_seed}.pt"
                torch.save(chunk_latents.cpu(), chunk_path)
                temp_latent_paths.append(chunk_path)

            base_filename = "narrative_video" if is_narrative else "single_video"
            all_tensors_cpu = [torch.load(p) for p in temp_latent_paths]
            final_latents = torch.cat(all_tensors_cpu, dim=2)

            video_path, latents_path = self._finalize_generation(final_latents, base_filename, used_seed)
            return video_path, latents_path, used_seed
        except Exception as e:
            logging.error(f"Error during unified generation: {e}", exc_info=True)
            return None, None, None
        finally:
            for path in temp_latent_paths:
                if path.exists():
                    path.unlink()
            self.finalize()
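    # Worked example of the chunking arithmetic above (numbers follow directly
    # from DEFAULT_FPS = 24.0 and FRAMES_ALIGNMENT = 8):
    #   duration=4.0 -> total_frames = _align(round(4.0 * 24)) = _align(96) = 96
    #   3 prompts    -> frames_per_chunk = (96 // 3 // 8) * 8 = 32, overlap_frames = 9
    #   chunk 0: 32 frames           -> _align(32, 'n*8+1') = 25
    #   chunk 1: 32 + 9 overlap = 41 -> _align(41, 'n*8+1') = 41
    #   chunk 2: 32 + 9 overlap = 41 -> _align(41, 'n*8+1') = 41
    #   concatenated after overlap trimming: 25 + 32 + 32 = 89 frames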
    # ==========================================================================
    # --- WORKER UNITS AND INTERNAL HELPERS ---
    # ==========================================================================

    def _log_conditioning_items(self, items: List[Union[ConditioningItem, LatentConditioningItem]]):
        """Logs detailed information about a list of ConditioningItem objects."""
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            # (Debug logging logic)
            pass

    @log_function_io
    def _generate_single_chunk_low(self, **kwargs) -> Optional[torch.Tensor]:
        """[WORKER] Calls the pipeline to generate a single chunk of latents."""
        # (The logic of this function remains the same)
        pass  # Placeholder

    @log_function_io
    def _finalize_generation(self, final_latents: torch.Tensor, base_filename: str, seed: int) -> Tuple[str, str]:
        """Consolidates latents, decodes them to video, and saves final artifacts."""
        logging.info("Finalizing generation: decoding latents to video.")
        final_latents_path = RESULTS_DIR / f"latents_{base_filename}_{seed}.pt"
        torch.save(final_latents, final_latents_path)
        logging.info(f"Final latents saved to: {final_latents_path}")

        pixel_tensor = vae_server_singleton.decode_to_pixels(
            final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05))
        )
        video_path = self._save_and_log_video(pixel_tensor, f"{base_filename}_{seed}")
        return str(video_path), str(final_latents_path)

    def _apply_ui_overrides(self, config_dict: Dict, overrides: Dict):
        """Applies advanced settings from the UI to a config dictionary."""
        # (UI override logic remains the same)
        pass  # Placeholder

    def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
        """Saves a pixel tensor (on CPU) to an MP4 file."""
        # (Video saving logic remains the same)
        pass  # Placeholder

    def _apply_precision_policy(self):
        # (Precision logic remains the same)
        pass  # Placeholder

    def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT, alignment_rule: str = 'default') -> int:
        """Aligns a dimension based on a rule."""
        if alignment_rule == 'n*8+1':
            # Round down to the nearest value of the form n*alignment + 1.
            return ((dim - 1) // alignment) * alignment + 1
        # Default: round up to the nearest multiple of alignment.
        return ((dim - 1) // alignment + 1) * alignment

    def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
        num_frames = int(round(duration_s * DEFAULT_FPS))
        aligned_frames = self._align(num_frames, alignment=FRAMES_ALIGNMENT)
        return max(aligned_frames, min_frames)

    def _get_random_seed(self) -> int:
        """Always generates and returns a new random seed."""
        return random.randint(0, 2**32 - 1)

# ==============================================================================
# --- SINGLETON INSTANTIATION ---
# ==============================================================================
try:
    video_generation_service = VideoService()
    logging.info("Global VideoService orchestrator instance created successfully.")
except Exception as e:
    logging.critical(f"Failed to initialize VideoService: {e}", exc_info=True)
    sys.exit(1)
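For reference, a minimal sketch of how a caller might drive this orchestrator. This is an assumption, not part of the upload: it presumes the file is saved as api/ltx_server_refactored_complete.py (the " (1)" suffix in the uploaded name is not importable), that gpu_manager and vae_server initialized successfully, and it uses a hypothetical reference image path. Since _generate_single_chunk_low is still a placeholder in this upload, the call only exercises the orchestration path.

    # Hypothetical caller sketch (see assumptions above).
    from PIL import Image
    from api.ltx_server_refactored_complete import video_generation_service

    first_frame = Image.open("ref.png")  # hypothetical reference image
    video_path, latents_path, seed = video_generation_service.generate_low_resolution(
        prompt_list=["a ship at sea", "the ship sails into a storm"],  # >1 prompt -> narrative mode
        initial_media_items=[(first_frame, 0, 1.0)],  # (media, target frame, strength)
        duration=4.0, height=512, width=768,          # kwargs read by the orchestrator
    )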
api/vae_server.py
ADDED
@@ -0,0 +1,162 @@
# FILE: api/vae_server.py
# DESCRIPTION: A dedicated, "hot" VAE service specialist.
# It loads the VAE model onto a dedicated GPU and keeps it in memory
# to handle all encoding and decoding requests with minimal latency.

import os
import sys
import time
import logging
from pathlib import Path
from typing import List, Union, Tuple

import torch
import numpy as np
from PIL import Image

from api.ltx_pool_manager import LatentConditioningItem
from api.gpu_manager import gpu_manager


# --- Architecture and LTX imports ---
try:
    # Add the LTX libraries to the import path
    LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
    if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
        sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))

    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
except ImportError as e:
    raise ImportError(f"A crucial import failed for VaeServer. Check dependencies. Error: {e}")


class VaeServer:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        logging.info("⚙️ Initializing VaeServer Singleton...")
        t0 = time.time()

        # 1. Get the dedicated VAE device from the central manager
        self.device = gpu_manager.get_ltx_vae_device()

        # 2. Load the VAE model from the LTX checkpoint.
        # We assume setup.py has already downloaded the models.
        try:
            from api.ltx_pool_manager import ltx_pool_manager
            # Reuse the configuration and pipeline already loaded by the LTX
            # Pool Manager to guarantee we are using the same VAE.
            self.vae = ltx_pool_manager.get_pipeline().vae
        except Exception as e:
            logging.critical(f"Failed to get VAE from LTXPoolManager. Is it initialized first? Error: {e}", exc_info=True)
            raise

        # 3. Ensure the VAE is on the correct device and in eval mode
        self.vae.to(self.device)
        self.vae.eval()
        self.dtype = self.vae.dtype

        self._initialized = True
        logging.info(f"✅ VaeServer ready. VAE model is 'hot' on {self.device} with dtype {self.dtype}. Startup time: {time.time() - t0:.2f}s")

    def _cleanup_gpu(self):
        """Frees VRAM on the VAE's GPU."""
        if torch.cuda.is_available():
            with torch.cuda.device(self.device):
                torch.cuda.empty_cache()

    def _preprocess_input(self, item: Union[Image.Image, torch.Tensor], target_resolution: Tuple[int, int]) -> torch.Tensor:
        """Prepares a PIL image or tensor into the pixel format the VAE expects."""
        if isinstance(item, Image.Image):
            from PIL import ImageOps
            img = item.convert("RGB")
            # Resize preserving aspect ratio, cropping the excess.
            # target_resolution is (height, width); PIL size tuples are
            # (width, height), so swap before calling ImageOps.fit.
            processed_img = ImageOps.fit(img, (target_resolution[1], target_resolution[0]), Image.Resampling.LANCZOS)
            image_np = np.array(processed_img).astype(np.float32) / 255.0
            tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
        elif isinstance(item, torch.Tensor):
            # If it is already a tensor, just make sure it is in CHW format
            if item.ndim == 4 and item.shape[0] == 1:  # Drop the batch dimension if present
                tensor = item.squeeze(0)
            elif item.ndim == 3:
                tensor = item
            else:
                raise ValueError(f"Input tensor must have 3 or 4 dimensions (CHW or BCHW), but got {item.ndim}")
        else:
            raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")

        # Convert to 5D (B, C, F, H, W) and normalize to [-1, 1]
        tensor_5d = tensor.unsqueeze(0).unsqueeze(2)  # Add B=1 and F=1
        return (tensor_5d * 2.0) - 1.0
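    # Shape walk-through for _preprocess_input (derived from the code above):
    #   PIL 640x480 RGB, target_resolution=(512, 768)  i.e. (height, width)
    #     -> ImageOps.fit crops/resizes to 768x512 (W x H)
    #     -> numpy array (512, 768, 3) in [0, 1] -> CHW tensor (3, 512, 768)
    #     -> unsqueeze to (1, 3, 1, 512, 768), rescaled to [-1, 1]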
    @torch.no_grad()
    def generate_conditioning_items(
        self,
        media_items: List[Union[Image.Image, torch.Tensor]],
        target_frames: List[int],
        strengths: List[float],
        target_resolution: Tuple[int, int]
    ) -> List[LatentConditioningItem]:
        """
        [MAIN FUNCTION]
        Converts a list of images (PIL or pixel tensors) into a list of
        LatentConditioningItem objects, ready to be used by the patched LTX pipeline.
        """
        t0 = time.time()
        logging.info(f"Generating {len(media_items)} latent conditioning items...")

        if not (len(media_items) == len(target_frames) == len(strengths)):
            raise ValueError("The media_items, target_frames and strengths lists must have the same length.")

        conditioning_items = []
        try:
            for item, frame, strength in zip(media_items, target_frames, strengths):
                # 1. Prepare the image/tensor into the correct pixel format
                pixel_tensor = self._preprocess_input(item, target_resolution)

                # 2. Move the pixel tensor to the VAE's GPU and encode it to a latent
                pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
                latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)

                # 3. Create the LatentConditioningItem (latent moved to CPU to avoid pinning VRAM)
                conditioning_items.append(LatentConditioningItem(latents.cpu(), frame, strength))

            logging.info(f"Generated {len(conditioning_items)} items in {time.time() - t0:.2f}s.")
            return conditioning_items
        finally:
            self._cleanup_gpu()

    @torch.no_grad()
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decodes a latent tensor into a pixel tensor on the CPU."""
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items_in_batch = latent_tensor_gpu.shape[0]
            timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=self.dtype)

            pixels = vae_decode(
                latent_tensor_gpu, self.vae, is_video=True,
                timestep=timestep_tensor, vae_per_channel_normalize=True
            )
            logging.info(f"Decoded latents with shape {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()  # Return on the CPU
        finally:
            self._cleanup_gpu()

# --- Singleton instance ---
# Initialization happens the first time the module is imported.
try:
    vae_server_singleton = VaeServer()
except Exception as e:
    logging.critical("CRITICAL: Failed to initialize VaeServer singleton.", exc_info=True)
    vae_server_singleton = None
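For reference, a minimal sketch of calling the VaeServer directly. An assumption, not part of the upload: it presumes gpu_manager and ltx_pool_manager came up before this module was imported (otherwise vae_server_singleton is None), uses a hypothetical image path, and whether decode_to_pixels cleanly round-trips a single-frame conditioning latent depends on the VAE's temporal handling.

    # Hypothetical caller sketch (see assumptions above).
    from PIL import Image
    from api.vae_server import vae_server_singleton

    assert vae_server_singleton is not None, "VaeServer failed to initialize"
    items = vae_server_singleton.generate_conditioning_items(
        media_items=[Image.open("ref.png")],  # hypothetical reference image
        target_frames=[0],                    # condition at the first frame
        strengths=[1.0],
        target_resolution=(512, 768),         # (height, width)
    )
    # Optional round-trip check: decode the freshly encoded latent back to pixels.
    pixels = vae_server_singleton.decode_to_pixels(items[0].latent_tensor)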
|