| | """ |
| | Helper functions to encode image or decode latents using Qwen-VAE |
| | """ |
| |
|
| | from typing import List |
| |
|
| | from diffusers import AutoencoderKLQwenImage |
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| |
|
| |
|
def encode_image(
    images: List[Image.Image], model: AutoencoderKLQwenImage
) -> torch.Tensor:
    """
    Encode a batch of PIL Images to latents.

    Each image is converted to 3-channel RGB, scaled from [0, 255] to
    [-1, 1], stacked into one batch and passed through ``model.encode``.
    The returned latents are *sampled* from the encoder's latent
    distribution, so repeated calls are not bit-identical.

    Args:
        images: List of PIL Images. Must be non-empty and all images must
            have the same width and height, because they are stacked into
            a single tensor.
        model: Qwen VAE model

    Returns:
        Latents tensor of shape [batch_size, channel, height, width]
    """
    # Read device and dtype from the same parameter instead of walking the
    # parameter iterator twice.
    reference_param = next(model.parameters())
    device, dtype = reference_param.device, reference_param.dtype

    tensors = []
    for img in images:
        # Force RGB so grayscale (2-D array), palette and RGBA inputs all
        # produce an H x W x 3 array; otherwise permute(2, 0, 1) fails or the
        # VAE receives the wrong number of channels.
        pixels = np.array(img.convert("RGB"))
        # Normalize in float32 and cast to the model dtype once at the end,
        # so low-precision dtypes (e.g. bfloat16) are rounded a single time.
        tensor = torch.from_numpy(pixels).permute(2, 0, 1).to(torch.float32)
        tensors.append(tensor / 127.5 - 1.0)

    batch = torch.stack(tensors).to(device=device, dtype=dtype)

    # The VAE is called with a 5-D input: insert a singleton axis at dim 2
    # (presumably the frame axis of [B, C, T, H, W] -- confirm against the
    # model's documentation).
    batch = batch.unsqueeze(2)

    with torch.no_grad():
        latents = model.encode(batch).latent_dist.sample()

    # Drop the singleton axis again to return a 4-D tensor.
    return latents.squeeze(2)
| |
|
| |
|
def decode_latents(
    latents: torch.Tensor, model: AutoencoderKLQwenImage
) -> List[Image.Image]:
    """
    Decode latents to a batch of PIL Images.

    The latents are moved to the model's device and dtype, decoded with
    ``model.decode``, clamped to [-1, 1] and mapped to 8-bit pixel values.

    Args:
        latents: Latents tensor of shape [B, C, H, W]
        model: Qwen VAE model

    Returns:
        List of PIL Images, one per batch element
    """
    reference_param = next(model.parameters())

    # Cast to the model's *dtype* as well as its device; a float32 latent fed
    # to a bfloat16 model would otherwise hit a dtype mismatch.
    # unsqueeze(2) adds the singleton axis the VAE is called with (5-D input).
    latents = latents.unsqueeze(2).to(
        device=reference_param.device, dtype=reference_param.dtype
    )

    with torch.no_grad():
        reconstructed = model.decode(latents).sample

    # Drop the singleton axis: back to [B, C, H, W].
    reconstructed = reconstructed.squeeze(2)

    # Do the pixel math in float32 so low-precision model dtypes do not
    # quantize the result, and round (not truncate) when converting to uint8.
    pixels = (reconstructed.float().clamp(-1.0, 1.0) + 1.0) * 127.5
    pixels = pixels.round().to(torch.uint8)

    # [B, C, H, W] -> [B, H, W, C], the channel-last layout Image.fromarray
    # expects.
    arrays = pixels.permute(0, 2, 3, 1).contiguous().cpu().numpy()
    return [Image.fromarray(img_array) for img_array in arrays]
| |
|
| |
|
if __name__ == "__main__":
    # Manual smoke test: encode one image, decode it again, save the result.
    vae = AutoencoderKLQwenImage.from_pretrained(
        "REPA-E/e2e-qwenimage-vae", torch_dtype=torch.bfloat16, device_map="cuda:0"
    )

    # PNG files may be RGBA, palette or grayscale; convert to RGB so the
    # encoder always receives three channels. The context manager closes the
    # underlying file once the pixel data has been loaded by convert().
    with Image.open("test_image.png") as source_image:
        test_image = source_image.convert("RGB")

    latents = encode_image([test_image], vae)
    print(f"Latents shape: {latents.shape}")
    print(f"Latents dtype: {latents.dtype}")
    print(f"Latents device: {latents.device}")

    # Round-trip the latents back to image space.
    reconstructed_images = decode_latents(latents, vae)
    print(f"Reconstructed {len(reconstructed_images)} images")

    reconstructed_images[0].save("test_reconstruction.png")
    print("Saved reconstructed image to test_reconstruction.png")
| |
|
| |
|
def get_vae_stats(
    vae_path: str, device: torch.device = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Get latent mean and std from VAE config.

    The config is read with ``transformers.PretrainedConfig`` first. If that
    fails for any reason (including ``transformers`` not being importable),
    the VAE itself is loaded from local files and its config is used.
    When the config carries no ``latents_mean`` / ``latents_std``, the
    function falls back to zeros / ones with 16 channels.

    Args:
        vae_path: Path to VAE model
        device: Device to put tensors on. ``None`` (the default) leaves the
            tensors where ``torch.tensor`` creates them.

    Returns:
        (mean, std) tuple of tensors with shape [1, C, 1, 1]
    """
    try:
        # Imported lazily and inside the try so that a missing optional
        # dependency falls through to the diffusers-based fallback below.
        from transformers import PretrainedConfig

        config = PretrainedConfig.from_pretrained(vae_path)
    except Exception:
        # Deliberate best-effort fallback: instantiate the VAE only to read
        # its config, then drop the reference to the weights.
        vae = AutoencoderKLQwenImage.from_pretrained(
            vae_path, torch_dtype=torch.bfloat16, local_files_only=True
        )
        config = vae.config
        del vae

    # Channel count used only when the config provides no statistics.
    default_channels = 16

    latents_mean = getattr(config, "latents_mean", None)
    if latents_mean is not None:
        mean = torch.tensor(latents_mean).view(1, -1, 1, 1)
    else:
        mean = torch.zeros(1, default_channels, 1, 1)

    latents_std = getattr(config, "latents_std", None)
    if latents_std is not None:
        std = torch.tensor(latents_std).view(1, -1, 1, 1)
    else:
        std = torch.ones(1, default_channels, 1, 1)

    # Compare against None rather than relying on truthiness: a device given
    # as the ordinal 0 is falsy and would otherwise be silently ignored.
    if device is not None:
        mean = mean.to(device)
        std = std.to(device)

    return mean, std
| |
|