Spaces:
Paused
Paused
File size: 8,102 Bytes
97682d1 1cacf10 c9413de 1cacf10 c9413de 9a6b3d7 1cacf10 c9413de 9a6b3d7 1cacf10 e13ea4b 9a6b3d7 c9413de 9a6b3d7 1cacf10 c9413de 9a6b3d7 1cacf10 9a6b3d7 1cacf10 9a6b3d7 1cacf10 9a6b3d7 1cacf10 9a6b3d7 1cacf10 9a6b3d7 1cacf10 9a6b3d7 1cacf10 9a6b3d7 1cacf10 e13ea4b 9a6b3d7 1cacf10 9a6b3d7 1cacf10 9a6b3d7 1cacf10 9a6b3d7 1cacf10 9a6b3d7 c9413de 1cacf10 c9413de 1cacf10 c9413de 9a6b3d7 1cacf10 c9413de 1cacf10 9a6b3d7 1cacf10 c9413de 1cacf10 9a6b3d7 1cacf10 9a6b3d7 1cacf10 c9413de 1cacf10 c9413de 1cacf10 c9413de 1cacf10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A high-level client for submitting VAE-related jobs to the LTXAducManager pool.
# It handles encoding media to latents, decoding latents to pixels, and creating ConditioningItems.
import logging
import time
import torch
import os
import torchvision.transforms.functional as TVF
from PIL import Image
from typing import List, Union, Tuple, Literal, Optional
from dataclasses import dataclass
from pathlib import Path
import sys
# O cliente importa o MANAGER para submeter os trabalhos ao pool de workers.
from api.ltx.ltx_aduc_manager import ltx_aduc_manager
# --- Adiciona o path do LTX-Video para importações de baixo nível ---
# Root of the local LTX-Video checkout whose low-level modules are imported below.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")

def add_deps_to_path():
    """Prepend the LTX-Video repository to sys.path (idempotent)."""
    repo = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo not in sys.path:
        sys.path.insert(0, repo)

add_deps_to_path()
# Importações para anotação de tipos e para as funções de trabalho (jobs).
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
import ltx_video.pipelines.crf_compressor as crf_compressor
# ==============================================================================
# --- DEFINIÇÕES DE ESTRUTURA E HELPERS ---
# ==============================================================================
@dataclass
class LatentConditioningItem:
    """
    Data structure for passing conditioned latents between services.

    The latent tensor is kept on the CPU to save VRAM between stages.
    """
    # VAE-encoded latent tensor; jobs in this module return it already on CPU.
    latent_tensor: torch.Tensor
    # Frame index in the target media where this conditioning applies.
    media_frame_number: int
    # Conditioning blend weight (presumably in [0, 1] — TODO confirm with consumer).
    conditioning_strength: float
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """
    Load and process an image into a 5D pixel tensor normalized to [-1, 1],
    ready to be sent to the VAE for encoding.
    """
    if isinstance(image_input, str):
        img = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        img = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    # Center-crop to the target aspect ratio so the subsequent resize
    # does not distort the image.
    src_w, src_h = img.size
    target_ar = target_width / target_height
    source_ar = src_w / src_h
    if source_ar > target_ar:
        # Source is wider than the target: crop the width.
        crop_w = int(src_h * target_ar)
        left = (src_w - crop_w) // 2
        img = img.crop((left, 0, left + crop_w, src_h))
    else:
        # Source is taller (or equal): crop the height.
        crop_h = int(src_w / target_ar)
        top = (src_h - crop_h) // 2
        img = img.crop((0, top, src_w, top + crop_h))
    img = img.resize((target_width, target_height), Image.Resampling.LANCZOS)

    # To tensor (C,H,W in [0,1]), CRF-compress in HWC layout, back to CHW,
    # then rescale to [-1, 1].
    tensor_chw = TVF.to_tensor(img)
    tensor_hwc = crf_compressor.compress(tensor_chw.permute(1, 2, 0))
    tensor_chw = tensor_hwc.permute(2, 0, 1)
    tensor_chw = (tensor_chw * 2.0) - 1.0
    # Add batch and frame dimensions -> (1, C, 1, H, W).
    return tensor_chw.unsqueeze(0).unsqueeze(2)
# ==============================================================================
# --- FUNÇÕES DE TRABALHO (Jobs a serem executados no Pool de VAE) ---
# ==============================================================================
def _job_encode_media(vae: CausalVideoAutoencoder, pixel_tensor: torch.Tensor) -> torch.Tensor:
    """Worker job: encode a pixel tensor into a latent tensor, returned on CPU."""
    # Match the VAE's device and dtype before running the encode pass.
    gpu_input = pixel_tensor.to(vae.device, dtype=vae.dtype)
    latents = vae_encode(gpu_input, vae, vae_per_channel_normalize=True)
    # Results go back to CPU so VRAM is freed between pipeline stages.
    return latents.cpu()
def _job_decode_latent(vae: CausalVideoAutoencoder, latent_tensor: torch.Tensor) -> torch.Tensor:
    """Worker job: decode a latent tensor into a pixel tensor, returned on CPU."""
    # Match the VAE's device and dtype before running the decode pass.
    gpu_input = latent_tensor.to(vae.device, dtype=vae.dtype)
    pixels = vae_decode(gpu_input, vae, is_video=True, vae_per_channel_normalize=True)
    # Results go back to CPU so VRAM is freed between pipeline stages.
    return pixels.cpu()
# ==============================================================================
# --- A CLASSE CLIENTE (Interface Pública) ---
# ==============================================================================
class VaeAducPipeline:
    """
    High-level client that orchestrates all VAE-related tasks.

    It implements the business logic (validation, image pre-processing) and
    submits the actual GPU work to the LTXAducManager worker pool.
    """
    def __init__(self):
        # Stateless client: all heavy resources live in the manager's workers.
        logging.info("✅ VAE ADUC Pipeline (Client) initialized and ready to submit jobs.")

    def __call__(
        self,
        media: Union[torch.Tensor, List[Union[Image.Image, str]]],
        task: Literal['encode', 'decode', 'create_conditioning_items'],
        target_resolution: Optional[Tuple[int, int]] = (512, 512),
        conditioning_params: Optional[List[Tuple[int, float]]] = None
    ) -> Union[List[torch.Tensor], torch.Tensor, List[LatentConditioningItem]]:
        """
        Main entry point for executing VAE tasks.

        Args:
            media: The input data. A single latent tensor for 'decode'; a list
                of images (PIL Images or file paths) for the other tasks.
            task: The task to run ('encode', 'decode', 'create_conditioning_items').
            target_resolution: (height, width) used for pre-processing. If None,
                falls back to (512, 512).
            conditioning_params: For 'create_conditioning_items', a list of
                (frame_number, strength) tuples, one per media item.

        Returns:
            The task result, always on the CPU.

        Raises:
            TypeError: If 'decode' receives a non-tensor input.
            ValueError: If inputs are inconsistent or the task is unknown.
        """
        t0 = time.time()
        logging.info(f"VAE Client received a '{task}' job.")
        # BUGFIX: the signature declares Optional[Tuple[int, int]], but the
        # original code indexed target_resolution unconditionally and would
        # crash on an explicit None; fall back to the documented default.
        height, width = target_resolution if target_resolution is not None else (512, 512)

        if task == 'encode':
            if not isinstance(media, list):
                media = [media]
            pixel_tensors = [
                load_image_to_tensor_with_resize_and_crop(m, height, width) for m in media
            ]
            result = [
                ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_encode_media, pixel_tensor=pt)
                for pt in pixel_tensors
            ]
        elif task == 'decode':
            if not isinstance(media, torch.Tensor):
                raise TypeError("Para a tarefa 'decode', 'media' deve ser um único tensor latente.")
            result = ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_decode_latent, latent_tensor=media)
        elif task == 'create_conditioning_items':
            if not isinstance(media, list) or not isinstance(conditioning_params, list) or len(media) != len(conditioning_params):
                raise ValueError("Para 'create_conditioning_items', 'media' e 'conditioning_params' devem ser listas de mesmo tamanho.")
            pixel_tensors = [
                load_image_to_tensor_with_resize_and_crop(m, height, width) for m in media
            ]
            result = []
            for (frame_number, strength), pt in zip(conditioning_params, pixel_tensors):
                latent_tensor = ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_encode_media, pixel_tensor=pt)
                result.append(LatentConditioningItem(
                    latent_tensor=latent_tensor,
                    media_frame_number=frame_number,
                    conditioning_strength=strength,
                ))
        else:
            raise ValueError(f"Tarefa desconhecida: '{task}'. Opções: 'encode', 'decode', 'create_conditioning_items'.")

        # BUGFIX: t0 was captured but never used; report the elapsed time.
        logging.info(f"VAE '{task}' job completed in {time.time() - t0:.2f}s.")
        return result
# --- SINGLETON CLIENT INSTANCE ---
try:
    vae_aduc_pipeline = VaeAducPipeline()
except Exception:
    # FIX: dropped the unused `as e` binding; exc_info=True already captures
    # the active exception. Keep the module importable on failure — callers
    # must check for None before use.
    logging.critical("CRITICAL: Failed to initialize the VaeAducPipeline client.", exc_info=True)
    vae_aduc_pipeline = None