# FILE: managers/vae_manager.py
# DESCRIPTION: Singleton manager for VAE decoding operations, supporting dedicated GPU devices.

import torch
import contextlib
import logging


class _SimpleVAEManager:
    """
    Coordinate VAE decoding for a generation pipeline.

    The VAE is allowed to live on a different device (for example a second
    GPU) than the rest of the pipeline. This manager moves incoming latents to
    the configured device, decodes them there, and hands the pixels back on
    the CPU.
    """

    def __init__(self):
        """Create a manager with no pipeline attached, targeting the CPU in float32."""
        self.pipeline = None
        # CPU is the fallback target until attach_pipeline() supplies a device.
        self.device = torch.device("cpu")
        self.autocast_dtype = torch.float32

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        """
        Register the pipeline whose ``vae`` attribute will be used for decoding.

        Args:
            pipeline: Pipeline object; decode() later requires it to expose a
                ``vae`` attribute.
            device (torch.device or str, optional): Device on which decoding
                runs (e.g. ``'cuda:1'``). When omitted, the previously
                configured device is kept.
            autocast_dtype (torch.dtype, optional): Precision handed to
                ``torch.autocast``. When omitted, the previous value is kept.
        """
        self.pipeline = pipeline

        if device is not None:
            self.device = torch.device(device)
            logging.info(f"[VAEManager] VAE device successfully set to: {self.device}")

        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """
        Decode latents into pixels on the manager's configured device.

        Args:
            latent_tensor (torch.Tensor): Latents to decode; they may arrive on
                any device and are moved to ``self.device`` first.
            decode_timestep (float): Value repeated once per batch item (along
                dimension 0) to build the timestep tensor passed to the VAE.

        Returns:
            torch.Tensor: Decoded pixels clamped to [-1, 1], rescaled to
            [0, 1], and returned on the CPU.

        Raises:
            RuntimeError: If no pipeline has been attached.
            AttributeError: If the attached pipeline has no ``vae`` attribute.
        """
        pipe = self.pipeline
        if pipe is None:
            raise RuntimeError("VAEManager: No pipeline has been attached. Call attach_pipeline() first.")
        if not hasattr(pipe, 'vae'):
            raise AttributeError("VAEManager: The attached pipeline does not have a 'vae' attribute.")

        target = self.device

        # Relocate the inputs before touching the model so that every operand
        # of the decode call shares the VAE's device.
        logging.debug(f"[VAEManager] Moving latents from {latent_tensor.device} to VAE device {target} for decoding.")
        latents = latent_tensor.to(target)

        # The VAE itself is not moved here; it is assumed to already sit on
        # `target` (NOTE(review): nothing in this file enforces that).
        model = pipe.vae

        batch_size = latents.shape[0]
        timesteps = torch.tensor([decode_timestep for _ in range(batch_size)], device=target)

        # Autocast is only switched on for CUDA devices; elsewhere the context
        # is created in a disabled state and has no effect.
        device_kind = target.type
        precision_ctx = torch.autocast(
            device_type=device_kind,
            dtype=self.autocast_dtype,
            enabled=(device_kind == 'cuda'),
        )

        with precision_ctx:
            logging.debug(f"[VAEManager] Decoding latents with shape {latents.shape} on {target}.")
            # Undo the VAE's scaling factor before decoding.
            unscaled = latents / model.config.scaling_factor
            # NOTE(review): confirm the VAE's decode() really takes a
            # `timesteps` keyword; some LTX VAE variants name it differently.
            decoded = model.decode(unscaled, timesteps=timesteps).sample

        # Map the [-1, 1] model output onto [0, 1].
        decoded = decoded.clamp(-1, 1).add(1.0).div(2.0)

        # CPU output is the safe default for downstream saving/display code.
        logging.debug(f"[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
        return decoded.cpu()


# Create a single, global instance of the manager to be used throughout the application.
vae_manager_singleton = _SimpleVAEManager()