File size: 6,507 Bytes
4dec1ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# src/vae_utils.py
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL
from typing import Optional
import numpy as np

class VAEManager:
    """Utility class for VAE encoding/decoding operations.

    Wraps a pretrained, frozen ``AutoencoderKL`` (diffusers) and exposes
    convenience methods to move between image space ([-1, 1] tensors) and
    the VAE latent space, including optional upscaling and PIL conversion.
    """

    def __init__(self, model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda"):
        """
        Args:
            model_name: HuggingFace model id or local path of the VAE.
            device: Torch device string the VAE is moved to (e.g. "cuda", "cpu").
        """
        self.device = device
        self.model_name = model_name
        self.vae = None
        self._load_vae()

    def _load_vae(self):
        """Load the VAE, move it to the target device, and freeze it."""
        print(f"Loading VAE: {self.model_name}")
        self.vae = AutoencoderKL.from_pretrained(self.model_name)
        self.vae = self.vae.to(self.device)
        self.vae.eval()

        # Freeze VAE parameters — this class is inference-only, so no
        # gradients should ever flow through the VAE.
        for param in self.vae.parameters():
            param.requires_grad = False

    def encode(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images to latent space.

        Args:
            images: Tensor of shape [B, 3, H, W] in range [-1, 1]

        Returns:
            latents: Tensor of shape [B, C, H//8, W//8] where C is
                     ``self.latent_channels`` (4 for Stable Diffusion VAEs),
                     already multiplied by the VAE scaling factor.
        """
        with torch.no_grad():
            images = images.to(self.device)
            # Stochastic sample from the posterior (not the mode); repeated
            # calls on the same input will differ slightly.
            latent_dist = self.vae.encode(images).latent_dist
            latents = latent_dist.sample()
            # Scale so latents are roughly unit-variance, as expected by
            # downstream diffusion models.
            latents = latents * self.vae.config.scaling_factor

        return latents

    def decode(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
               upscale_mode: str = 'bicubic') -> torch.Tensor:
        """
        Decode latents to images.

        Args:
            latents: Tensor of shape [B, C, H, W] (scaled latents, as
                     produced by :meth:`encode`)
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x, 1.5 for 1.5x)
                          If None, returns images at native resolution (H*8, W*8)
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')

        Returns:
            images: Tensor of shape [B, 3, H*8*upscale_factor, W*8*upscale_factor] in range [-1, 1]
        """
        with torch.no_grad():
            latents = latents.to(self.device)
            # Undo the scaling applied in encode() before decoding.
            latents = latents / self.vae.config.scaling_factor
            images = self.vae.decode(latents).sample

            # Apply upscaling if requested (1.0 is a no-op, skip it).
            if upscale_factor is not None and upscale_factor != 1.0:
                _, _, h, w = images.shape
                new_h = int(h * upscale_factor)
                new_w = int(w * upscale_factor)
                images = F.interpolate(
                    images,
                    size=(new_h, new_w),
                    mode=upscale_mode,
                    # align_corners is only valid for the smooth modes;
                    # 'nearest' requires it to be left as None.
                    align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
                    # Antialiasing reduces artifacts for the smooth modes and
                    # is unsupported for 'nearest'.
                    antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
                )

        return images

    def decode_to_pil(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
                     upscale_mode: str = 'bicubic', target_size: Optional[tuple] = None):
        """
        Decode latents to PIL images.

        Args:
            latents: Tensor of shape [B, C, H, W]
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x)
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')
            target_size: Optional target size as (height, width). Applied after
                         upscale_factor, so it effectively overrides it.

        Returns:
            pil_images: List of PIL images (RGB, uint8)
        """
        # Imported lazily so the module does not hard-require Pillow.
        from PIL import Image

        # Decode to tensor
        images = self.decode(latents, upscale_factor=upscale_factor, upscale_mode=upscale_mode)

        # Apply target size if specified
        if target_size is not None:
            images = F.interpolate(
                images,
                size=target_size,
                mode=upscale_mode,
                align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
                antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
            )

        # Map [-1, 1] -> [0, 1] and clamp out-of-range decoder output.
        images = (images + 1.0) / 2.0
        images = torch.clamp(images, 0, 1)

        # Convert each CHW float tensor to an HWC uint8 PIL image.
        pil_images = []
        for i in range(images.shape[0]):
            img_array = images[i].cpu().numpy().transpose(1, 2, 0)
            img_array = (img_array * 255).astype(np.uint8)
            pil_images.append(Image.fromarray(img_array))

        return pil_images

    @property
    def scaling_factor(self) -> float:
        """Get VAE scaling factor"""
        return self.vae.config.scaling_factor

    @property
    def latent_channels(self) -> int:
        """Get number of latent channels"""
        # Read from the loaded model's config instead of hard-coding 4, so
        # non-standard VAEs report correctly; falls back to 4 (the Stable
        # Diffusion default) for configs lacking the attribute.
        return getattr(self.vae.config, "latent_channels", 4)

def create_vae_manager(model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda") -> VAEManager:
    """Build and return a :class:`VAEManager` for the given model and device."""
    manager = VAEManager(model_name, device)
    return manager

def save_images_from_latents(latents: torch.Tensor, save_dir: str, vae_manager: VAEManager, prefix: str = "sample"):
    """
    Decode latents with the VAE and write each resulting image to disk as PNG.

    Args:
        latents: Tensor of shape [B, 4, H, W]
        save_dir: Directory to save images (created if missing)
        vae_manager: VAE manager instance used for decoding
        prefix: Filename prefix; files are named ``{prefix}_{idx:03d}.png``
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    # Decode all latents to PIL images in one call, then write them out.
    pil_images = vae_manager.decode_to_pil(latents)

    for idx, img in enumerate(pil_images):
        img.save(os.path.join(save_dir, f"{prefix}_{idx:03d}.png"))

    print(f"Saved {len(pil_images)} images to {save_dir}")

def create_image_grid(latents: torch.Tensor, vae_manager: VAEManager, nrow: int = 4) -> torch.Tensor:
    """
    Decode latents and arrange the resulting images into a single grid tensor.

    Args:
        latents: Tensor of shape [B, 4, H, W]
        vae_manager: VAE manager instance used for decoding
        nrow: Number of images per row in the grid

    Returns:
        grid: Image grid tensor with values in [0, 1]
    """
    import torchvision.utils as vutils

    # Decode, then shift from the VAE's [-1, 1] range into [0, 1] for display.
    decoded = vae_manager.decode(latents)
    decoded = torch.clamp((decoded + 1.0) / 2.0, 0, 1)

    return vutils.make_grid(decoded, nrow=nrow, padding=2)