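"""Utility helpers around diffusers' AutoencoderKL: encode images to latents and decode latents back to images."""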
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL
from typing import Optional
import numpy as np


class VAEManager:
    """Utility class for VAE encoding/decoding operations"""

    def __init__(self, model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda"):
        self.device = device
        self.model_name = model_name
        self.vae = None
        self._load_vae()

    def _load_vae(self):
        """Load VAE model"""
        print(f"Loading VAE: {self.model_name}")
        self.vae = AutoencoderKL.from_pretrained(self.model_name)
        self.vae = self.vae.to(self.device)
        self.vae.eval()

        # Freeze all VAE parameters; this manager is inference-only.
        for param in self.vae.parameters():
            param.requires_grad = False

    def encode(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images to latent space

        Args:
            images: Tensor of shape [B, 3, H, W] in range [-1, 1]

        Returns:
            latents: Tensor of shape [B, 4, H//8, W//8]
        """
        with torch.no_grad():
            images = images.to(self.device)
            latent_dist = self.vae.encode(images).latent_dist
            latents = latent_dist.sample()
            # Scale latents by the VAE scaling factor, following the latent-diffusion convention.
            latents = latents * self.vae.config.scaling_factor

        return latents

    def decode(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
               upscale_mode: str = 'bicubic') -> torch.Tensor:
        """
        Decode latents to images

        Args:
            latents: Tensor of shape [B, 4, H, W]
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x, 1.5 for 1.5x).
                If None, returns images at native resolution (H*8, W*8).
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')

        Returns:
            images: Tensor of shape [B, 3, H*8*upscale_factor, W*8*upscale_factor] in range [-1, 1]
        """
        with torch.no_grad():
            latents = latents.to(self.device)

            # Undo the scaling applied in encode() before running the decoder.
            latents = latents / self.vae.config.scaling_factor
            images = self.vae.decode(latents).sample

            # Optionally upscale the decoded images; antialiasing only applies to bilinear/bicubic modes.
            if upscale_factor is not None and upscale_factor != 1.0:
                _, _, h, w = images.shape
                new_h = int(h * upscale_factor)
                new_w = int(w * upscale_factor)
                images = F.interpolate(
                    images,
                    size=(new_h, new_w),
                    mode=upscale_mode,
                    align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
                    antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
                )

        return images

    def decode_to_pil(self, latents: torch.Tensor, upscale_factor: Optional[float] = None,
                      upscale_mode: str = 'bicubic', target_size: Optional[tuple] = None):
        """
        Decode latents to PIL images

        Args:
            latents: Tensor of shape [B, 4, H, W]
            upscale_factor: Optional upscaling factor (e.g., 2.0 for 2x)
            upscale_mode: Interpolation mode ('bicubic', 'bilinear', 'nearest')
            target_size: Optional target size as (height, width). Overrides upscale_factor if provided.

        Returns:
            pil_images: List of PIL images
        """
        from PIL import Image

        images = self.decode(latents, upscale_factor=upscale_factor, upscale_mode=upscale_mode)

        if target_size is not None:
            images = F.interpolate(
                images,
                size=target_size,
                mode=upscale_mode,
                align_corners=False if upscale_mode in ['bilinear', 'bicubic'] else None,
                antialias=True if upscale_mode in ['bilinear', 'bicubic'] else False
            )

        images = (images + 1.0) / 2.0
        images = torch.clamp(images, 0, 1)

        pil_images = []
        for i in range(images.shape[0]):
            img_array = images[i].cpu().numpy().transpose(1, 2, 0)
            img_array = (img_array * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_array)
            pil_images.append(pil_image)

        return pil_images

    @property
    def scaling_factor(self) -> float:
        """Get VAE scaling factor"""
        return self.vae.config.scaling_factor

    @property
    def latent_channels(self) -> int:
        """Get number of latent channels"""
        return self.vae.config.latent_channels


def create_vae_manager(model_name: str = "stabilityai/sd-vae-ft-mse", device: str = "cuda") -> VAEManager:
    """Factory function to create VAE manager"""
    return VAEManager(model_name, device)


def save_images_from_latents(latents: torch.Tensor, save_dir: str, vae_manager: VAEManager, prefix: str = "sample"):
    """
    Save images from latents using VAE decoder

    Args:
        latents: Tensor of shape [B, 4, H, W]
        save_dir: Directory to save images
        vae_manager: VAE manager instance
        prefix: Filename prefix
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    pil_images = vae_manager.decode_to_pil(latents)

    for i, pil_image in enumerate(pil_images):
        save_path = os.path.join(save_dir, f"{prefix}_{i:03d}.png")
        pil_image.save(save_path)

    print(f"Saved {len(pil_images)} images to {save_dir}")


def create_image_grid(latents: torch.Tensor, vae_manager: VAEManager, nrow: int = 4) -> torch.Tensor:
    """
    Create image grid from latents

    Args:
        latents: Tensor of shape [B, 4, H, W]
        vae_manager: VAE manager instance
        nrow: Number of images per row

    Returns:
        grid: Image grid tensor
    """
    import torchvision.utils as vutils

    images = vae_manager.decode(latents)

    # Shift decoded images from [-1, 1] to [0, 1] so the grid is directly displayable.
    images = (images + 1.0) / 2.0
    images = torch.clamp(images, 0, 1)

    grid = vutils.make_grid(images, nrow=nrow, padding=2)

    return grid
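

if __name__ == "__main__":
    # Minimal round-trip sketch, not part of the module's API: it assumes a CUDA
    # device is available and that the "stabilityai/sd-vae-ft-mse" weights can be
    # downloaded from the Hugging Face Hub. Adjust device/model_name as needed.
    manager = create_vae_manager()

    # A dummy batch of two 256x256 RGB images in [-1, 1].
    images = torch.rand(2, 3, 256, 256) * 2.0 - 1.0

    latents = manager.encode(images)  # expected shape: [2, 4, 32, 32]
    print("latent shape:", tuple(latents.shape))

    reconstructions = manager.decode(latents, upscale_factor=2.0)  # expected shape: [2, 3, 512, 512]
    print("decoded shape:", tuple(reconstructions.shape))

    save_images_from_latents(latents, save_dir="vae_samples", vae_manager=manager)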