File size: 7,120 Bytes
2c5f020
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# FILE: api/vae_server.py
# DESCRIPTION: A dedicated, always-resident ("hot") VAE service.
# It takes the VAE from the LTX pipeline, moves it to the device reserved for
# the VAE, and keeps it in memory to serve encoding and decoding requests
# with minimal latency.

import os
import sys
import time
import logging
from pathlib import Path
from typing import List, Union, Tuple

import torch
import numpy as np
from PIL import Image

from api.ltx_pool_manager import LatentConditioningItem
from api.gpu_manager import gpu_manager


# --- Architecture and LTX imports ---
try:
    # Make the LTX-Video checkout importable by putting it first on sys.path
    # (only once, even if this module is reloaded).
    LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
    _ltx_repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
    if _ltx_repo_path not in sys.path:
        sys.path.insert(0, _ltx_repo_path)

    from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
    from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
except ImportError as e:
    # Chain the original error explicitly so the traceback shows which
    # module/dependency was actually missing.
    raise ImportError(f"A crucial import failed for VaeServer. Check dependencies. Error: {e}") from e


class VaeServer:
    """Process-wide singleton that keeps the LTX VAE resident on its device.

    The VAE object is taken from the pipeline returned by
    ``ltx_pool_manager.get_pipeline()`` (so no second copy of the weights is
    loaded) and moved to the device returned by
    ``gpu_manager.get_ltx_vae_device()``. It offers pixel -> latent encoding
    (:meth:`generate_conditioning_items`) and latent -> pixel decoding
    (:meth:`decode_to_pixels`).
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Singleton: every ``VaeServer()`` call returns the same object.
        # Python still re-runs __init__ on each call, so ``_initialized``
        # lets __init__ skip the expensive setup after the first time.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Attach to the shared LTX VAE and move it to the dedicated VAE device.

        The body runs only once per process; later calls return immediately.

        Raises:
            Exception: whatever importing ``ltx_pool_manager`` or calling
                ``get_pipeline()`` raises (e.g. the pool manager was not
                initialized first). It is logged as critical and re-raised.
        """
        if self._initialized:
            return

        logging.info("⚙️ Initializing VaeServer Singleton...")
        t0 = time.time()

        # 1. Ask the central GPU manager which device is reserved for the VAE.
        self.device = gpu_manager.get_ltx_vae_device()

        # 2. Reuse the VAE from the LTX pipeline so encoding/decoding uses
        #    exactly the same weights. Assumes the models were already
        #    downloaded by the setup step.
        try:
            from api.ltx_pool_manager import ltx_pool_manager
            self.vae = ltx_pool_manager.get_pipeline().vae
        except Exception as e:
            logging.critical(f"Failed to get VAE from LTXPoolManager. Is it initialized first? Error: {e}", exc_info=True)
            raise

        # 3. Ensure the VAE is on the chosen device and in evaluation mode.
        # NOTE(review): if get_pipeline() returns a shared pipeline object, this
        # also moves that pipeline's VAE to ``self.device`` — confirm intended.
        self.vae.to(self.device)
        self.vae.eval()
        self.dtype = self.vae.dtype

        self._initialized = True
        logging.info(f"✅ VaeServer ready. VAE model is 'hot' on {self.device} with dtype {self.dtype}. Startup time: {time.time() - t0:.2f}s")

    def _cleanup_gpu(self):
        """Release cached, unused CUDA memory on the VAE's device.

        This runs inside ``finally`` blocks, so it must not raise. It is a no-op
        when CUDA is unavailable or when the VAE device is not a CUDA device
        (``torch.cuda.device`` rejects non-CUDA devices with ``ValueError``,
        which would otherwise mask the caller's result or original exception).
        """
        if not torch.cuda.is_available():
            return
        device = torch.device(self.device)
        if device.type != "cuda":
            return
        with torch.cuda.device(device):
            torch.cuda.empty_cache()

    def _preprocess_input(self, item: Union[Image.Image, torch.Tensor], target_resolution: Tuple[int, int]) -> torch.Tensor:
        """Convert a PIL image or a pixel tensor into the 5-D layout the VAE expects.

        Args:
            item: A PIL image (any mode; converted to RGB), or a tensor shaped
                (C, H, W) or (1, C, H, W). Tensor values are assumed to already
                be in [0, 1], like the PIL path produces — TODO confirm callers.
            target_resolution: Size passed to ``PIL.ImageOps.fit`` for PIL input,
                i.e. (width, height). NOTE(review): tensor input is NOT resized
                to this resolution.

        Returns:
            A tensor shaped (1, C, 1, H, W) (batch and frame axes added),
            mapped from [0, 1] to [-1, 1].

        Raises:
            ValueError: tensor input with an unsupported number of dimensions,
                or a 4-D tensor whose batch size is not 1.
            TypeError: input that is neither a PIL image nor a tensor.
        """
        if isinstance(item, Image.Image):
            from PIL import ImageOps
            rgb_image = item.convert("RGB")
            # Resize preserving aspect ratio to cover the target, then crop the overflow.
            fitted = ImageOps.fit(rgb_image, target_resolution, Image.Resampling.LANCZOS)
            pixels = np.array(fitted).astype(np.float32) / 255.0
            tensor = torch.from_numpy(pixels).permute(2, 0, 1)  # HWC -> CHW
        elif isinstance(item, torch.Tensor):
            if item.ndim == 3:
                tensor = item
            elif item.ndim == 4:
                if item.shape[0] != 1:
                    # Only a single image can be encoded per item; a real batch
                    # has 4 dimensions, so the generic message below would be wrong.
                    raise ValueError(
                        f"Input tensor with 4 dimensions must have a batch size of 1 (1CHW), but got shape {tuple(item.shape)}"
                    )
                tensor = item.squeeze(0)  # drop the singleton batch axis
            else:
                raise ValueError(f"Input tensor must have 3 or 4 dimensions (CHW or BCHW), but got {item.ndim}")
        else:
            raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")

        # Add batch (B=1) and frame (F=1) axes -> (B, C, F, H, W), then map [0, 1] -> [-1, 1].
        tensor_5d = tensor.unsqueeze(0).unsqueeze(2)
        return (tensor_5d * 2.0) - 1.0

    @torch.no_grad()
    def generate_conditioning_items(
        self,
        media_items: List[Union[Image.Image, torch.Tensor]],
        target_frames: List[int],
        strengths: List[float],
        target_resolution: Tuple[int, int]
    ) -> List[LatentConditioningItem]:
        """Encode media into ``LatentConditioningItem``s for the LTX pipeline.

        Each media item is preprocessed (see ``_preprocess_input``), moved to
        the VAE device/dtype, encoded with
        ``vae_encode(..., vae_per_channel_normalize=True)``, and wrapped
        together with its target frame and strength. Latents are moved back to
        the CPU so they do not hold VRAM while waiting to be used. Cached GPU
        memory is released afterwards, even on error.

        Args:
            media_items: PIL images or pixel tensors to encode.
            target_frames: Target frame for each media item (presumably a frame
                index in the generated video — confirm with the pipeline).
            strengths: Conditioning strength for each media item.
            target_resolution: Resolution forwarded to ``_preprocess_input``.

        Returns:
            One ``LatentConditioningItem`` per media item, in input order.

        Raises:
            ValueError: if the three lists differ in length, or a tensor input
                has an unsupported shape.
            TypeError: if a media item is neither a PIL image nor a tensor.
        """
        t0 = time.time()
        logging.info(f"Generating {len(media_items)} latent conditioning items...")

        if not (len(media_items) == len(target_frames) == len(strengths)):
            raise ValueError("As listas de media_items, target_frames e strengths devem ter o mesmo tamanho.")

        conditioning_items = []
        try:
            for item, frame, strength in zip(media_items, target_frames, strengths):
                # 1. Bring the image/tensor into the pixel layout the VAE expects.
                pixel_tensor = self._preprocess_input(item, target_resolution)

                # 2. Move the pixels to the VAE device and encode them to latents.
                pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
                latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)

                # 3. Keep the latent on the CPU so it does not occupy VRAM.
                conditioning_items.append(LatentConditioningItem(latents.cpu(), frame, strength))

            logging.info(f"Generated {len(conditioning_items)} items in {time.time() - t0:.2f}s.")
            return conditioning_items
        finally:
            self._cleanup_gpu()

    @torch.no_grad()
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """Decode a latent tensor to pixels and return the result on the CPU.

        Args:
            latent_tensor: Latents to decode; dimension 0 is treated as the batch
                axis (one timestep value is created per entry along it).
            decode_timestep: Timestep value passed to ``vae_decode`` for every
                batch entry.

        Returns:
            The ``vae_decode`` output (called with ``is_video=True`` and
            ``vae_per_channel_normalize=True``) moved to the CPU. Its exact
            layout and value range are determined by ``vae_decode``.
        """
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items_in_batch = latent_tensor_gpu.shape[0]
            # One (identical) timestep per batch entry, on the VAE device/dtype.
            timestep_tensor = torch.full(
                (num_items_in_batch,), decode_timestep, device=self.device, dtype=self.dtype
            )

            pixels = vae_decode(
                latent_tensor_gpu, self.vae, is_video=True,
                timestep=timestep_tensor, vae_per_channel_normalize=True
            )
            logging.info(f"Decoded latents with shape {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()  # hand results back on the CPU to free VRAM
        finally:
            self._cleanup_gpu()

# --- Singleton instance ---
# Created the first time this module is imported. If construction fails, the
# error is logged with its traceback and the module-level handle stays None,
# so importers can check for a missing service instead of failing at import.
vae_server_singleton = None
try:
    vae_server_singleton = VaeServer()
except Exception:
    logging.critical("CRITICAL: Failed to initialize VaeServer singleton.", exc_info=True)