# baguette/src/vae_utils.py — "Initial upload: Paris MoE inference code and weights" (commit 4dec1ca, verified)
# src/vae_utils.py
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL
from typing import Optional
import numpy as np
class VAEManager:
    """Utility wrapper around a pretrained ``AutoencoderKL`` for encoding
    images to latents and decoding latents back to images.

    The underlying VAE is loaded once, moved to ``device``, put in eval
    mode, and frozen (no gradients flow through it).
    """

    def __init__(self, model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda"):
        """
        Args:
            model_name: Hugging Face model id or local path of the VAE.
            device: Device string the VAE is moved to (e.g. "cuda", "cpu").
        """
        self.device = device
        self.model_name = model_name
        self.vae = None
        self._load_vae()

    def _load_vae(self):
        """Load the VAE, move it to the target device, and freeze it."""
        print(f"Loading VAE: {self.model_name}")
        self.vae = AutoencoderKL.from_pretrained(self.model_name)
        self.vae = self.vae.to(self.device)
        self.vae.eval()
        # Freeze VAE parameters — this module is inference-only.
        for param in self.vae.parameters():
            param.requires_grad = False

    @staticmethod
    def _resize(images: torch.Tensor, size: tuple, mode: str) -> torch.Tensor:
        """Resize decoded images to ``size`` (H, W).

        ``align_corners`` and ``antialias`` are only meaningful for the
        smooth modes ('bilinear'/'bicubic'); for 'nearest' they must be
        None/False respectively, or interpolate raises.
        """
        smooth = mode in ('bilinear', 'bicubic')
        return F.interpolate(
            images,
            size=size,
            mode=mode,
            align_corners=False if smooth else None,
            antialias=smooth,
        )

    def encode(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images to latent space.

        Args:
            images: Tensor of shape [B, 3, H, W] in range [-1, 1].
        Returns:
            latents: Tensor of shape [B, C_latent, H//8, W//8], already
                multiplied by the VAE scaling factor. Note: this samples
                from the posterior, so results are stochastic.
        """
        with torch.no_grad():
            images = images.to(self.device)
            latent_dist = self.vae.encode(images).latent_dist
            latents = latent_dist.sample()
            # Scale into the normalized latent space diffusion models expect.
            latents = latents * self.vae.config.scaling_factor
            return latents

    def decode(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
               upscale_mode: str = 'bicubic') -> torch.Tensor:
        """
        Decode latents to images.

        Args:
            latents: Tensor of shape [B, C_latent, H, W] (scaled latents).
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x, 1.5 for 1.5x).
                If None (or 1.0), returns images at native resolution (H*8, W*8).
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest').
        Returns:
            images: Tensor of shape [B, 3, H*8*upscale_factor, W*8*upscale_factor]
                in range [-1, 1] (approximately; the decoder output is not clamped here).
        """
        with torch.no_grad():
            latents = latents.to(self.device)
            # Undo the scaling applied in encode() before decoding.
            latents = latents / self.vae.config.scaling_factor
            images = self.vae.decode(latents).sample
            # Apply upscaling if requested.
            if upscale_factor is not None and upscale_factor != 1.0:
                _, _, h, w = images.shape
                new_h = int(h * upscale_factor)
                new_w = int(w * upscale_factor)
                images = self._resize(images, (new_h, new_w), upscale_mode)
            return images

    def decode_to_pil(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
                      upscale_mode: str = 'bicubic', target_size: Optional[tuple] = None):
        """
        Decode latents to PIL images.

        Args:
            latents: Tensor of shape [B, C_latent, H, W].
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x).
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest').
            target_size: Optional target size as (height, width). Applied after
                (and therefore overrides the net effect of) upscale_factor.
        Returns:
            pil_images: List of PIL images (RGB, uint8).
        """
        from PIL import Image
        # Decode to tensor in [-1, 1].
        images = self.decode(latents, upscale_factor=upscale_factor, upscale_mode=upscale_mode)
        # Apply exact target size if specified.
        if target_size is not None:
            images = self._resize(images, target_size, upscale_mode)
        # Map [-1, 1] -> [0, 1] and clip decoder overshoot.
        images = (images + 1.0) / 2.0
        images = torch.clamp(images, 0, 1)
        # Convert each CHW tensor to an HWC uint8 PIL image.
        pil_images = []
        for i in range(images.shape[0]):
            img_array = images[i].cpu().numpy().transpose(1, 2, 0)
            img_array = (img_array * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_array)
            pil_images.append(pil_image)
        return pil_images

    @property
    def scaling_factor(self) -> float:
        """VAE latent scaling factor (from the model config)."""
        return self.vae.config.scaling_factor

    @property
    def latent_channels(self) -> int:
        """Number of latent channels, read from the model config.

        Falls back to 4 (the Stable Diffusion standard) if the config
        does not declare ``latent_channels``.
        """
        return getattr(self.vae.config, "latent_channels", 4)
def create_vae_manager(model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda") -> VAEManager:
    """Convenience factory: build and return a ready-to-use VAEManager."""
    manager = VAEManager(model_name, device)
    return manager
def save_images_from_latents(latents: torch.Tensor, save_dir: str, vae_manager: VAEManager, prefix: str = "sample"):
    """
    Decode a batch of latents with the VAE and write each frame to disk as a PNG.

    Args:
        latents: Tensor of shape [B, 4, H, W]
        save_dir: Directory to save images (created if it does not exist)
        vae_manager: VAE manager instance
        prefix: Filename prefix
    """
    import os

    os.makedirs(save_dir, exist_ok=True)
    # Decode the whole batch to PIL images, then write them out one by one.
    images = vae_manager.decode_to_pil(latents)
    for index, image in enumerate(images):
        destination = os.path.join(save_dir, f"{prefix}_{index:03d}.png")
        image.save(destination)
    print(f"Saved {len(images)} images to {save_dir}")
def create_image_grid(latents: torch.Tensor, vae_manager: VAEManager, nrow: int = 4) -> torch.Tensor:
    """
    Decode latents and arrange the resulting images into a single grid tensor.

    Args:
        latents: Tensor of shape [B, 4, H, W]
        vae_manager: VAE manager instance
        nrow: Number of images per row
    Returns:
        grid: Image grid tensor with values in [0, 1]
    """
    import torchvision.utils as vutils

    # Decode, then map from the VAE's [-1, 1] output range into [0, 1],
    # clipping any decoder overshoot.
    decoded = vae_manager.decode(latents)
    normalized = torch.clamp((decoded + 1.0) / 2.0, 0, 1)
    return vutils.make_grid(normalized, nrow=nrow, padding=2)