# src/vae_utils.py
"""Utilities for encoding/decoding images with a frozen AutoencoderKL VAE."""

import torch
import torch.nn.functional as F
from typing import Optional
import numpy as np


class VAEManager:
    """Utility class for VAE encoding/decoding operations.

    Loads a frozen, eval-mode ``AutoencoderKL`` once and exposes
    image <-> latent conversions plus PIL-export helpers.
    """

    def __init__(self, model_name: str = "stabilityai/sd-vae-ft-mse",
                 device: str = "cuda"):
        """
        Args:
            model_name: HuggingFace model id (or local path) of the VAE.
            device: Device the VAE is moved to (e.g. "cuda", "cpu").
        """
        self.device = device
        self.model_name = model_name
        self.vae = None
        self._load_vae()

    def _load_vae(self):
        """Load the VAE, move it to ``self.device``, and freeze its weights."""
        # Lazy import keeps this module importable without diffusers
        # installed, matching the lazy PIL/torchvision imports below.
        from diffusers import AutoencoderKL

        print(f"Loading VAE: {self.model_name}")
        self.vae = AutoencoderKL.from_pretrained(self.model_name)
        self.vae = self.vae.to(self.device)
        self.vae.eval()

        # This manager never trains the VAE, so freeze all parameters.
        for param in self.vae.parameters():
            param.requires_grad = False

    @staticmethod
    def _interp_kwargs(mode: str) -> dict:
        """Keyword args for ``F.interpolate`` that are valid for ``mode``.

        ``align_corners`` is only meaningful for linear modes, and
        ``antialias`` is only supported for bilinear/bicubic, so both are
        disabled for 'nearest'.
        """
        smooth = mode in ('bilinear', 'bicubic')
        return {
            'align_corners': False if smooth else None,
            'antialias': smooth,
        }

    def encode(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images to latent space.

        Args:
            images: Tensor of shape [B, 3, H, W] in range [-1, 1]

        Returns:
            latents: Tensor of shape [B, 4, H//8, W//8], multiplied by the
                VAE's scaling factor (Stable Diffusion convention).
        """
        with torch.no_grad():
            images = images.to(self.device)
            latent_dist = self.vae.encode(images).latent_dist
            # Stochastic sample from the posterior, scaled so latents match
            # what diffusion models expect.
            latents = latent_dist.sample() * self.vae.config.scaling_factor
        return latents

    def decode(self, latents: torch.Tensor,
               upscale_factor: Optional[float] = None,
               upscale_mode: str = 'bicubic') -> torch.Tensor:
        """
        Decode latents to images.

        Args:
            latents: Tensor of shape [B, 4, H, W]
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x,
                1.5 for 1.5x). If None (or 1.0), returns images at native
                resolution (H*8, W*8).
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')

        Returns:
            images: Tensor of shape
                [B, 3, H*8*upscale_factor, W*8*upscale_factor]
                in range [-1, 1] (not clamped; the decoder may overshoot).
        """
        with torch.no_grad():
            latents = latents.to(self.device)
            # Undo the scaling applied in encode() before running the decoder.
            latents = latents / self.vae.config.scaling_factor
            images = self.vae.decode(latents).sample

            # Apply upscaling only when it actually changes the size.
            if upscale_factor is not None and upscale_factor != 1.0:
                _, _, h, w = images.shape
                new_size = (int(h * upscale_factor), int(w * upscale_factor))
                images = F.interpolate(images, size=new_size,
                                       mode=upscale_mode,
                                       **self._interp_kwargs(upscale_mode))
        return images

    def decode_to_pil(self, latents: torch.Tensor,
                      upscale_factor: Optional[float] = None,
                      upscale_mode: str = 'bicubic',
                      target_size: Optional[tuple] = None):
        """
        Decode latents to PIL images.

        Args:
            latents: Tensor of shape [B, 4, H, W]
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x)
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')
            target_size: Optional target size as (height, width).
                Overrides upscale_factor if provided.

        Returns:
            pil_images: List of PIL images
        """
        from PIL import Image

        if target_size is not None:
            # target_size overrides upscale_factor, so decode at native
            # resolution and resample ONCE. (Previously the images were
            # resampled twice -- by factor, then to target size -- which
            # wasted work and degraded quality.)
            images = self.decode(latents)
            images = F.interpolate(images, size=target_size,
                                   mode=upscale_mode,
                                   **self._interp_kwargs(upscale_mode))
        else:
            images = self.decode(latents, upscale_factor=upscale_factor,
                                 upscale_mode=upscale_mode)

        # Map [-1, 1] -> [0, 1] and clamp decoder overshoot.
        images = ((images + 1.0) / 2.0).clamp(0, 1)

        pil_images = []
        for i in range(images.shape[0]):
            # CHW float -> HWC uint8; round (rather than truncate) to avoid
            # a systematic darkening of up to 1/255 per channel.
            img_array = images[i].cpu().numpy().transpose(1, 2, 0)
            img_array = (img_array * 255).round().astype(np.uint8)
            pil_images.append(Image.fromarray(img_array))
        return pil_images

    @property
    def scaling_factor(self) -> float:
        """Get VAE scaling factor (from the model config)."""
        return self.vae.config.scaling_factor

    @property
    def latent_channels(self) -> int:
        """Get number of latent channels (from config when available)."""
        if self.vae is not None and hasattr(self.vae.config, 'latent_channels'):
            return self.vae.config.latent_channels
        return 4  # Standard for Stable Diffusion VAEs


def create_vae_manager(model_name: str = "stabilityai/sd-vae-ft-mse",
                       device: str = "cuda") -> VAEManager:
    """Factory function to create VAE manager."""
    return VAEManager(model_name, device)


def save_images_from_latents(latents: torch.Tensor, save_dir: str,
                             vae_manager: VAEManager,
                             prefix: str = "sample"):
    """
    Save images from latents using VAE decoder.

    Args:
        latents: Tensor of shape [B, 4, H, W]
        save_dir: Directory to save images (created if missing)
        vae_manager: VAE manager instance
        prefix: Filename prefix
    """
    import os
    os.makedirs(save_dir, exist_ok=True)

    pil_images = vae_manager.decode_to_pil(latents)
    for i, pil_image in enumerate(pil_images):
        # Zero-padded index keeps files lexicographically sorted.
        pil_image.save(os.path.join(save_dir, f"{prefix}_{i:03d}.png"))
    print(f"Saved {len(pil_images)} images to {save_dir}")


def create_image_grid(latents: torch.Tensor, vae_manager: VAEManager,
                      nrow: int = 4) -> torch.Tensor:
    """
    Create image grid from latents.

    Args:
        latents: Tensor of shape [B, 4, H, W]
        vae_manager: VAE manager instance
        nrow: Number of images per row

    Returns:
        grid: Image grid tensor in [0, 1]
    """
    import torchvision.utils as vutils

    images = vae_manager.decode(latents)
    # Map [-1, 1] -> [0, 1] before tiling so make_grid sees valid pixels.
    images = ((images + 1.0) / 2.0).clamp(0, 1)
    return vutils.make_grid(images, nrow=nrow, padding=2)