"""
Helper functions to encode image or decode latents using Qwen-VAE
"""
from typing import List, Optional

import numpy as np
import torch
from diffusers import AutoencoderKLQwenImage
from PIL import Image
def encode_image(
    images: List[Image.Image], model: AutoencoderKLQwenImage
) -> torch.Tensor:
    """
    Encode a batch of PIL Images to latents.

    Args:
        images: List of PIL Images; converted to RGB if they are not already
            (grayscale or RGBA inputs are handled transparently)
        model: Qwen VAE model

    Returns:
        Latents tensor of shape [batch_size, channel, height, width]
    """
    param = next(model.parameters())
    device, dtype = param.device, param.dtype
    # Convert PIL images to tensors and normalize to [-1, 1].
    # convert("RGB") guards against grayscale inputs (a 2-D array would break
    # the permute below) and RGBA inputs (4 channels would not match the VAE).
    tensors = []
    for img in images:
        array = np.array(img.convert("RGB"))
        tensor = torch.from_numpy(array).permute(2, 0, 1).to(dtype)
        tensors.append(tensor / 127.5 - 1.0)
    batch = torch.stack(tensors).to(device)  # [b, c, h, w]
    # The video VAE expects a temporal dimension: [b, c, 1, h, w]
    batch = batch.unsqueeze(2)
    with torch.no_grad():
        latents = model.encode(batch).latent_dist.sample()
    # Drop the temporal dimension back to [b, c, h, w]
    return latents.squeeze(2)
def decode_latents(
    latents: torch.Tensor, model: AutoencoderKLQwenImage
) -> List[Image.Image]:
    """
    Decode latents to a batch of PIL Images.

    Args:
        latents: Latents tensor of shape [B, C, H, W]
        model: Qwen VAE model

    Returns:
        List of PIL Images
    """
    param = next(model.parameters())
    device = param.device
    # BUG FIX: this previously read `.device` a second time instead of
    # `.dtype`, so the latents were never cast to the model's dtype below.
    dtype = param.dtype
    # Add temporal dimension [b, c, 1, h, w] and match the model's device/dtype
    latents = latents.unsqueeze(2).to(device=device, dtype=dtype)
    with torch.no_grad():
        reconstructed = model.decode(latents).sample
    # Remove temporal dimension [b, c, h, w]
    reconstructed = reconstructed.squeeze(2)
    # Denormalize from [-1, 1] to [0, 255] and convert to PIL
    images = []
    for tensor in reconstructed:
        tensor = torch.clamp(tensor, -1.0, 1.0)
        tensor = (tensor + 1.0) * 127.5
        img_array = tensor.permute(1, 2, 0).cpu().to(torch.uint8).numpy()
        images.append(Image.fromarray(img_array))
    return images
if __name__ == "__main__":
    # Smoke test: round-trip a single image through the VAE and save the result.
    model = AutoencoderKLQwenImage.from_pretrained(
        "REPA-E/e2e-qwenimage-vae", torch_dtype=torch.bfloat16, device_map="cuda:0"
    )
    latents = encode_image([Image.open("test_image.png")], model)
    print(f"Latents shape: {latents.shape}")
    print(f"Latents dtype: {latents.dtype}")
    print(f"Latents device: {latents.device}")
    # Decode the latents back into PIL images
    reconstructed_images = decode_latents(latents, model)
    print(f"Reconstructed {len(reconstructed_images)} images")
    # Persist the first reconstruction for visual inspection
    reconstructed_images[0].save("test_reconstruction.png")
    print("Saved reconstructed image to test_reconstruction.png")
def get_vae_stats(
    vae_path: str, device: Optional[torch.device] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Get latent mean and std from VAE config.

    Args:
        vae_path: Path to VAE model
        device: Optional device to move the returned tensors to

    Returns:
        (mean, std) tuple of tensors with shape [1, C, 1, 1]
    """
    # Load config only to be fast (avoids instantiating the full model)
    from transformers import PretrainedConfig

    try:
        config = PretrainedConfig.from_pretrained(vae_path)
    except Exception:
        # Fallback to loading full model if config load fails (e.g. local path issues)
        vae = AutoencoderKLQwenImage.from_pretrained(
            vae_path, torch_dtype=torch.bfloat16, local_files_only=True
        )
        config = vae.config
        del vae
    # Fall back to identity statistics (zero mean, unit std) when the config
    # does not declare them; 16 matches the Qwen VAE latent channel count.
    latents_mean = getattr(config, "latents_mean", None)
    latents_std = getattr(config, "latents_std", None)
    if latents_mean is not None:
        mean = torch.tensor(latents_mean).view(1, -1, 1, 1)
    else:
        mean = torch.zeros(1, 16, 1, 1)
    if latents_std is not None:
        std = torch.tensor(latents_std).view(1, -1, 1, 1)
    else:
        std = torch.ones(1, 16, 1, 1)
    # `is not None` rather than truthiness, so any passed device object is
    # always honored.
    if device is not None:
        mean = mean.to(device)
        std = std.to(device)
    return mean, std