File size: 8,102 Bytes
97682d1
1cacf10
 
c9413de
 
1cacf10
c9413de
9a6b3d7
1cacf10
c9413de
9a6b3d7
1cacf10
e13ea4b
9a6b3d7
c9413de
9a6b3d7
1cacf10
c9413de
9a6b3d7
 
 
 
 
 
 
 
 
 
 
 
1cacf10
 
9a6b3d7
1cacf10
 
 
 
 
 
9a6b3d7
1cacf10
 
 
 
 
 
 
 
 
 
 
 
9a6b3d7
1cacf10
 
 
 
9a6b3d7
1cacf10
 
 
9a6b3d7
1cacf10
 
 
 
 
9a6b3d7
 
1cacf10
9a6b3d7
 
 
 
1cacf10
e13ea4b
9a6b3d7
 
1cacf10
 
 
 
 
 
 
 
9a6b3d7
1cacf10
 
 
9a6b3d7
1cacf10
 
 
 
 
 
9a6b3d7
 
1cacf10
 
 
 
 
 
 
 
 
 
 
9a6b3d7
 
 
 
c9413de
1cacf10
 
c9413de
1cacf10
c9413de
9a6b3d7
1cacf10
 
 
 
c9413de
1cacf10
 
 
 
 
 
 
9a6b3d7
1cacf10
 
 
c9413de
 
1cacf10
 
 
 
 
9a6b3d7
1cacf10
 
 
 
9a6b3d7
 
1cacf10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9413de
 
1cacf10
 
 
 
c9413de
1cacf10
c9413de
1cacf10
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A high-level client for submitting VAE-related jobs to the LTXAducManager pool.
# It handles encoding media to latents, decoding latents to pixels, and creating ConditioningItems.

import logging
import time
import torch
import os
import torchvision.transforms.functional as TVF
from PIL import Image
from typing import List, Union, Tuple, Literal, Optional
from dataclasses import dataclass
from pathlib import Path
import sys

# The client imports the MANAGER in order to submit jobs to the worker pool.
from api.ltx.ltx_aduc_manager import ltx_aduc_manager

# --- Add the LTX-Video path for low-level imports ---
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")

def add_deps_to_path():
    """Ensure the LTX-Video repository root is on sys.path (idempotent)."""
    repo_root = str(LTX_VIDEO_REPO_DIR.resolve())
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)

add_deps_to_path()

# Imports used for type annotations and by the job functions below.
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
import ltx_video.pipelines.crf_compressor as crf_compressor

# ==============================================================================
# --- STRUCTURE DEFINITIONS AND HELPERS ---
# ==============================================================================

@dataclass
class LatentConditioningItem:
    """
    Data structure for passing conditioned latents between services.

    The latent tensor is kept on the CPU to save VRAM between stages.
    """
    # VAE-encoded latent for one media item (CPU-resident between stages).
    latent_tensor: torch.Tensor
    # Frame index in the target video at which this conditioning applies.
    media_frame_number: int
    # Blending weight for this item — presumably in [0, 1]; TODO confirm with the consumer pipeline.
    conditioning_strength: float

def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """
    Load and process an image into a 5D pixel tensor normalized to [-1, 1],
    ready to be sent to the VAE for encoding.

    The image is center-cropped to the target aspect ratio, resized with
    LANCZOS, passed through the CRF compressor, and reshaped to
    (1, C, 1, H, W).
    """
    if isinstance(image_input, str):
        img = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        img = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    # Center-crop to the target aspect ratio so the resize does not distort.
    src_w, src_h = img.size
    target_ar = target_width / target_height
    source_ar = src_w / src_h
    if source_ar > target_ar:
        # Source is wider than the target: trim equal amounts left/right.
        crop_w = int(src_h * target_ar)
        left = (src_w - crop_w) // 2
        img = img.crop((left, 0, left + crop_w, src_h))
    else:
        # Source is taller than the target: trim equal amounts top/bottom.
        crop_h = int(src_w / target_ar)
        top = (src_h - crop_h) // 2
        img = img.crop((0, top, src_w, top + crop_h))

    img = img.resize((target_width, target_height), Image.Resampling.LANCZOS)

    # To tensor (C, H, W) in [0, 1]; the CRF compressor operates on HWC layout.
    chw = TVF.to_tensor(img)
    hwc = crf_compressor.compress(chw.permute(1, 2, 0))
    chw = hwc.permute(2, 0, 1)

    # Map [0, 1] -> [-1, 1] and add batch + frame dims: (1, C, 1, H, W).
    normalized = (chw * 2.0) - 1.0
    return normalized.unsqueeze(0).unsqueeze(2)

# ==============================================================================
# --- JOB FUNCTIONS (Jobs to be executed on the VAE worker pool) ---
# ==============================================================================

def _job_encode_media(vae: CausalVideoAutoencoder, pixel_tensor: torch.Tensor) -> torch.Tensor:
    """Job that encodes a pixel tensor into a latent tensor.

    Moves the input onto the VAE's device/dtype, encodes, and returns the
    latents on the CPU so VRAM is released between stages.
    """
    gpu_pixels = pixel_tensor.to(vae.device, dtype=vae.dtype)
    latents = vae_encode(gpu_pixels, vae, vae_per_channel_normalize=True)
    return latents.cpu()

def _job_decode_latent(vae: CausalVideoAutoencoder, latent_tensor: torch.Tensor) -> torch.Tensor:
    """Job that decodes a latent tensor into a pixel tensor.

    Moves the input onto the VAE's device/dtype, decodes, and returns the
    pixels on the CPU so VRAM is released between stages.
    """
    gpu_latents = latent_tensor.to(vae.device, dtype=vae.dtype)
    pixels = vae_decode(gpu_latents, vae, is_video=True, vae_per_channel_normalize=True)
    return pixels.cpu()

# ==============================================================================
# --- THE CLIENT CLASS (Public Interface) ---
# ==============================================================================

class VaeAducPipeline:
    """
    High-level client that orchestrates every VAE-related task.

    It implements the business logic (validation, image pre-processing) and
    submits the actual GPU work to the LTXAducManager worker pool. Results
    are always returned on the CPU.
    """

    def __init__(self):
        # The client is stateless; all workers are owned by ltx_aduc_manager.
        logging.info("✅ VAE ADUC Pipeline (Client) initialized and ready to submit jobs.")

    def __call__(
        self,
        media: Union[torch.Tensor, List[Union[Image.Image, str]]],
        task: Literal['encode', 'decode', 'create_conditioning_items'],
        target_resolution: Optional[Tuple[int, int]] = (512, 512),
        conditioning_params: Optional[List[Tuple[int, float]]] = None
    ) -> Union[List[torch.Tensor], torch.Tensor, List[LatentConditioningItem]]:
        """
        Main entry point for executing VAE tasks.

        Args:
            media: The input data — a latent tensor for 'decode'; a list of
                images/paths (a single item is also accepted) for 'encode'
                and 'create_conditioning_items'.
            task: The task to execute ('encode', 'decode', 'create_conditioning_items').
            target_resolution: (height, width) used for image pre-processing.
            conditioning_params: For 'create_conditioning_items', a list of
                (frame_number, strength) tuples, one per media item.

        Returns:
            The task result, always on the CPU.

        Raises:
            TypeError: If 'decode' receives anything but a tensor.
            ValueError: If the task name is unknown, or the conditioning
                lists are missing or of mismatched length.
        """
        t0 = time.time()
        logging.info(f"VAE Client received a '{task}' job.")

        if task == 'encode':
            result = self._encode(media, target_resolution)
        elif task == 'decode':
            result = self._decode(media)
        elif task == 'create_conditioning_items':
            result = self._create_conditioning_items(media, target_resolution, conditioning_params)
        else:
            raise ValueError(f"Tarefa desconhecida: '{task}'. Opções: 'encode', 'decode', 'create_conditioning_items'.")

        # BUGFIX: 't0' was captured but never used — now reports job duration.
        logging.info(f"VAE Client finished '{task}' job in {time.time() - t0:.2f}s.")
        return result

    def _to_pixel_tensors(self, media: List[Union[Image.Image, str]],
                          target_resolution: Tuple[int, int]) -> List[torch.Tensor]:
        """Convert a list of images/paths into 5D pixel tensors at (height, width)."""
        height, width = target_resolution
        return [load_image_to_tensor_with_resize_and_crop(m, height, width) for m in media]

    def _encode(self, media, target_resolution) -> List[torch.Tensor]:
        """Encode one or more media items into latent tensors (one pool job per item)."""
        if not isinstance(media, list):
            media = [media]
        pixel_tensors = self._to_pixel_tensors(media, target_resolution)
        return [
            ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_encode_media, pixel_tensor=pt)
            for pt in pixel_tensors
        ]

    def _decode(self, media) -> torch.Tensor:
        """Decode a single latent tensor into pixels via the worker pool."""
        if not isinstance(media, torch.Tensor):
            raise TypeError("Para a tarefa 'decode', 'media' deve ser um único tensor latente.")
        return ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_decode_latent, latent_tensor=media)

    def _create_conditioning_items(self, media, target_resolution,
                                   conditioning_params) -> List[LatentConditioningItem]:
        """Encode each media item and pair it with its (frame_number, strength) params."""
        if not isinstance(media, list) or not isinstance(conditioning_params, list) or len(media) != len(conditioning_params):
            raise ValueError("Para 'create_conditioning_items', 'media' e 'conditioning_params' devem ser listas de mesmo tamanho.")

        pixel_tensors = self._to_pixel_tensors(media, target_resolution)
        conditioning_items = []
        for pt, (frame_number, strength) in zip(pixel_tensors, conditioning_params):
            latent_tensor = ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_encode_media, pixel_tensor=pt)
            conditioning_items.append(LatentConditioningItem(
                latent_tensor=latent_tensor,
                media_frame_number=frame_number,
                conditioning_strength=strength
            ))
        return conditioning_items

# --- CLIENT SINGLETON INSTANCE ---
# Fail-soft initialization: expose a None client rather than crashing the
# importing module; consumers must check for None before use.
try:
    vae_aduc_pipeline = VaeAducPipeline()
except Exception:  # broad by design at this boundary — log and degrade (fix: removed unused 'as e' binding; exc_info=True already captures the traceback)
    logging.critical("CRITICAL: Failed to initialize the VaeAducPipeline client.", exc_info=True)
    vae_aduc_pipeline = None