File size: 6,507 Bytes
4dec1ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
# src/vae_utils.py
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL
from typing import Optional
import numpy as np
class VAEManager:
    """Utility class for VAE encoding/decoding operations.

    Wraps a frozen, pretrained AutoencoderKL and provides helpers that
    handle the latent scaling factor, optional upscaling, and PIL export.
    """

    def __init__(self, model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda"):
        """
        Args:
            model_name: Hugging Face model id of the VAE checkpoint.
            device: Device the model is moved to (e.g. "cuda", "cpu").
        """
        self.device = device
        self.model_name = model_name
        self.vae = None  # populated by _load_vae()
        self._load_vae()

    def _load_vae(self):
        """Load the pretrained VAE, move it to the target device, and freeze it."""
        print(f"Loading VAE: {self.model_name}")
        self.vae = AutoencoderKL.from_pretrained(self.model_name)
        self.vae = self.vae.to(self.device)
        self.vae.eval()
        # Freeze VAE parameters -- the VAE is used for inference only.
        for param in self.vae.parameters():
            param.requires_grad = False

    def encode(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images to latent space.

        Args:
            images: Tensor of shape [B, 3, H, W] in range [-1, 1]
        Returns:
            latents: Tensor of shape [B, 4, H//8, W//8]
        """
        with torch.no_grad():
            images = images.to(self.device)
            latent_dist = self.vae.encode(images).latent_dist
            # Draw a stochastic sample from the posterior, then scale into
            # the normalized latent space expected by diffusion models.
            latents = latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor
        return latents

    def decode(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
               upscale_mode: str = 'bicubic') -> torch.Tensor:
        """
        Decode latents to images.

        Args:
            latents: Tensor of shape [B, 4, H, W]
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x, 1.5 for 1.5x)
                            If None, returns images at native resolution (H*8, W*8)
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')
        Returns:
            images: Tensor of shape [B, 3, H*8*upscale_factor, W*8*upscale_factor] in range [-1, 1]
        """
        with torch.no_grad():
            latents = latents.to(self.device)
            # Undo the scaling applied in encode() before decoding.
            latents = latents / self.vae.config.scaling_factor
            images = self.vae.decode(latents).sample
            # Apply upscaling if requested (a factor of 1.0 is a no-op).
            if upscale_factor is not None and upscale_factor != 1.0:
                _, _, h, w = images.shape
                new_h = int(h * upscale_factor)
                new_w = int(w * upscale_factor)
                # align_corners/antialias only apply to the smooth modes;
                # 'nearest' requires align_corners=None and no antialiasing.
                images = F.interpolate(
                    images,
                    size=(new_h, new_w),
                    mode=upscale_mode,
                    align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
                    antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
                )
        return images

    def decode_to_pil(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
                      upscale_mode: str = 'bicubic', target_size: Optional[tuple] = None):
        """
        Decode latents to PIL images.

        Args:
            latents: Tensor of shape [B, 4, H, W]
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x)
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')
            target_size: Optional target size as (height, width). Overrides upscale_factor if provided.
        Returns:
            pil_images: List of PIL images
        """
        from PIL import Image
        # Decode to tensor in [-1, 1]
        images = self.decode(latents, upscale_factor=upscale_factor, upscale_mode=upscale_mode)
        # Apply exact target size if specified (takes precedence over the factor).
        if target_size is not None:
            images = F.interpolate(
                images,
                size=target_size,
                mode=upscale_mode,
                align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
                antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
            )
        # Convert to [0, 1] range
        images = (images + 1.0) / 2.0
        images = torch.clamp(images, 0, 1)
        # Convert each [3, H, W] tensor to an HWC uint8 PIL image.
        pil_images = []
        for i in range(images.shape[0]):
            img_array = images[i].cpu().numpy().transpose(1, 2, 0)
            # Round instead of truncating so e.g. 0.999 maps to 255, not 254.
            img_array = (img_array * 255).round().astype(np.uint8)
            pil_images.append(Image.fromarray(img_array))
        return pil_images

    @property
    def scaling_factor(self) -> float:
        """Get VAE scaling factor"""
        return self.vae.config.scaling_factor

    @property
    def latent_channels(self) -> int:
        """Number of latent channels, read from the VAE config when available."""
        # Fall back to 4 (standard for Stable Diffusion VAEs) when the
        # config does not declare it or the VAE has not been loaded.
        if self.vae is not None and hasattr(self.vae.config, "latent_channels"):
            return self.vae.config.latent_channels
        return 4
def create_vae_manager(model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda") -> VAEManager:
    """Build and return a :class:`VAEManager` for the given model id and device."""
    manager = VAEManager(model_name, device)
    return manager
def save_images_from_latents(latents: torch.Tensor, save_dir: str, vae_manager: VAEManager,
                             prefix: str = "sample", upscale_factor: Optional[float] = None,
                             upscale_mode: str = 'bicubic', target_size: Optional[tuple] = None):
    """
    Save images from latents using VAE decoder

    Args:
        latents: Tensor of shape [B, 4, H, W]
        save_dir: Directory to save images (created if it does not exist)
        vae_manager: VAE manager instance
        prefix: Filename prefix
        upscale_factor: Optional upscaling factor forwarded to decode_to_pil
        upscale_mode: Interpolation mode forwarded to decode_to_pil
        target_size: Optional (height, width) forwarded to decode_to_pil;
                     overrides upscale_factor if provided
    """
    import os
    os.makedirs(save_dir, exist_ok=True)
    # Decode to PIL images; defaults reproduce the previous fixed behavior
    # (native resolution), while callers can now control the output size.
    pil_images = vae_manager.decode_to_pil(latents, upscale_factor=upscale_factor,
                                           upscale_mode=upscale_mode, target_size=target_size)
    # Zero-padded index keeps filenames lexicographically sorted.
    for i, pil_image in enumerate(pil_images):
        save_path = os.path.join(save_dir, f"{prefix}_{i:03d}.png")
        pil_image.save(save_path)
    print(f"Saved {len(pil_images)} images to {save_dir}")
def create_image_grid(latents: torch.Tensor, vae_manager: VAEManager, nrow: int = 4) -> torch.Tensor:
    """
    Create image grid from latents

    Args:
        latents: Tensor of shape [B, 4, H, W]
        vae_manager: VAE manager instance
        nrow: Number of images per row
    Returns:
        grid: Image grid tensor
    """
    import torchvision.utils as vutils
    # Decode, then map from the VAE's [-1, 1] output range into [0, 1]
    # before tiling into a single grid tensor.
    decoded = vae_manager.decode(latents)
    normalized = torch.clamp((decoded + 1.0) / 2.0, 0, 1)
    return vutils.make_grid(normalized, nrow=nrow, padding=2)