File size: 4,263 Bytes
b8a0748
 
8815ceb
 
 
b8a0748
829e1b9
8815ceb
b8a0748
 
 
 
 
 
 
 
 
8815ceb
 
b8a0748
 
 
 
 
 
 
 
 
8815ceb
 
b8a0748
 
8815ceb
 
 
441491f
 
b8a0748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441491f
b8a0748
 
 
 
 
 
 
 
 
8815ceb
b8a0748
 
 
 
441491f
b8a0748
 
 
 
 
 
 
8815ceb
b8a0748
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# FILE: managers/vae_manager.py
# DESCRIPTION: Singleton manager for VAE decoding operations, supporting dedicated GPU devices.

import torch
import contextlib
import logging

class _SimpleVAEManager:
    """
    Manages VAE decoding. It's designed to be aware that the VAE might reside
    on a different GPU than the main generation pipeline (e.g., Transformer).
    """
    def __init__(self):
        """Initializes the manager without a pipeline attached."""
        self.pipeline = None
        # Defaults to CPU until a device is attached via attach_pipeline().
        self.device = torch.device("cpu")
        self.autocast_dtype = torch.float32

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        """
        Attaches the main pipeline and, crucially, stores the specific device
        that this manager and its associated VAE should operate on.

        Args:
            pipeline: The main LTX video pipeline instance.
            device (torch.device or str): The target device for VAE operations (e.g., 'cuda:1').
            autocast_dtype (torch.dtype): The precision for torch.autocast.
        """
        self.pipeline = pipeline
        if device is not None:
            self.device = torch.device(device)
            # Lazy %-style args: the message is only formatted if the level is enabled.
            logging.info("[VAEManager] VAE device successfully set to: %s", self.device)
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    @torch.no_grad()
    def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """
        Decodes a latent tensor into a pixel tensor.

        This method ensures that the decoding operation happens on the correct,
        potentially dedicated, VAE device.

        Args:
            latent_tensor (torch.Tensor): The latents to decode, typically on the main device or CPU.
            decode_timestep (float): The timestep for VAE decoding.

        Returns:
            torch.Tensor: The resulting pixel tensor, moved to the CPU for general use.

        Raises:
            RuntimeError: If no pipeline has been attached yet.
            AttributeError: If the attached pipeline has no 'vae' attribute.
        """
        if self.pipeline is None:
            raise RuntimeError("VAEManager: No pipeline has been attached. Call attach_pipeline() first.")
        if not hasattr(self.pipeline, 'vae'):
            raise AttributeError("VAEManager: The attached pipeline does not have a 'vae' attribute.")

        # 1. Move the input latents to the dedicated VAE device. This is the critical step.
        logging.debug(
            "[VAEManager] Moving latents from %s to VAE device %s for decoding.",
            latent_tensor.device, self.device,
        )
        latent_tensor_on_vae_device = latent_tensor.to(self.device)

        # 2. Get a reference to the VAE model.
        #    NOTE(review): assumes the VAE was already placed on self.device by the
        #    caller/pipeline setup — this method does not move the model itself.
        vae = self.pipeline.vae

        # 3. Prepare other necessary tensors on the same VAE device.
        #    torch.full avoids materializing a Python list just to build the tensor.
        num_items_in_batch = latent_tensor_on_vae_device.shape[0]
        timestep_tensor = torch.full((num_items_in_batch,), decode_timestep, device=self.device)

        # 4. Set up the autocast context for the target device type.
        #    Autocast is only enabled on CUDA; on CPU the context is a no-op.
        autocast_device_type = self.device.type
        ctx = torch.autocast(
            device_type=autocast_device_type,
            dtype=self.autocast_dtype,
            enabled=(autocast_device_type == 'cuda')
        )

        # 5. Perform the decoding operation within the autocast context.
        with ctx:
            logging.debug(
                "[VAEManager] Decoding latents with shape %s on %s.",
                latent_tensor_on_vae_device.shape, self.device,
            )
            # The VAE expects latents scaled by its scaling factor.
            scaled_latents = latent_tensor_on_vae_device / vae.config.scaling_factor
            # 'timesteps' keyword is the project VAE's decode signature — confirm against model.
            pixels = vae.decode(scaled_latents, timesteps=timestep_tensor).sample

        # 6. Post-process the output: map the VAE's [-1, 1] output into [0, 1].
        pixels = (pixels.clamp(-1, 1) + 1.0) / 2.0

        # 7. Move the final pixel tensor to the CPU. This is a safe default, as subsequent
        #    operations like video saving or UI display typically expect CPU tensors.
        logging.debug("[VAEManager] Decoding complete. Moving pixel tensor to CPU.")
        return pixels.cpu()

# Create a single, global instance of the manager to be used throughout the application.
# Callers import this name (not the class) so every call site shares the same
# attached pipeline, device, and autocast settings.
vae_manager_singleton = _SimpleVAEManager()