"""
Helper functions to encode image or decode latents using Qwen-VAE
"""

from typing import List, Optional

import numpy as np
import torch
from diffusers import AutoencoderKLQwenImage
from PIL import Image


def encode_image(
    images: List[Image.Image], model: AutoencoderKLQwenImage
) -> torch.Tensor:
    """
    Encode a batch of PIL Images into VAE latents.

    Args:
        images: List of PIL Images (assumed RGB/3-channel — TODO confirm callers)
        model: Qwen VAE model

    Returns:
        Latents tensor of shape [batch_size, channel, height, width],
        on the model's device.
    """
    # Take device/dtype from the model's own parameters so inputs always match.
    param = next(model.parameters())
    device, dtype = param.device, param.dtype

    # PIL (H, W, C) uint8 -> (C, H, W) tensor normalized to [-1, 1].
    pixel_tensors = [
        torch.from_numpy(np.array(img)).permute(2, 0, 1).to(dtype) / 127.5 - 1.0
        for img in images
    ]

    # Stack to [b, c, h, w], then add the temporal axis the video VAE
    # expects: [b, c, 1, h, w].
    batch = torch.stack(pixel_tensors).to(device).unsqueeze(2)

    with torch.no_grad():
        latents = model.encode(batch).latent_dist.sample()

    # Drop the singleton temporal axis -> [b, c, h, w].
    return latents.squeeze(2)


def decode_latents(
    latents: torch.Tensor, model: AutoencoderKLQwenImage
) -> List[Image.Image]:
    """
    Decode latents to a batch of PIL Images.

    Args:
        latents: Latents tensor of shape [B, C, H, W]
        model: Qwen VAE model

    Returns:
        List of PIL Images
    """
    param = next(model.parameters())
    device = param.device
    # BUG FIX: this previously read `param.device` a second time, so `dtype`
    # held a torch.device and the cast below was a no-op device move. The
    # latents must be cast to the model's parameter dtype (e.g. bfloat16).
    dtype = param.dtype

    # Add temporal dimension expected by the video VAE: [b, c, 1, h, w]
    latents = latents.unsqueeze(2).to(device=device, dtype=dtype)

    with torch.no_grad():
        reconstructed = model.decode(latents).sample

    # Remove temporal dimension -> [b, c, h, w]
    reconstructed = reconstructed.squeeze(2)

    # Denormalize from [-1, 1] to [0, 255] and convert to PIL
    images = []
    for tensor in reconstructed:
        tensor = torch.clamp(tensor, -1.0, 1.0)
        tensor = (tensor + 1.0) * 127.5

        # Cast to uint8 on-device first: .numpy() would fail on bfloat16.
        img_array = tensor.permute(1, 2, 0).cpu().to(torch.uint8).numpy()
        images.append(Image.fromarray(img_array))

    return images


if __name__ == "__main__":
    # Smoke test: round-trip a single image through the VAE on GPU.
    vae = AutoencoderKLQwenImage.from_pretrained(
        "REPA-E/e2e-qwenimage-vae", torch_dtype=torch.bfloat16, device_map="cuda:0"
    )

    source_image = Image.open("test_image.png")

    encoded = encode_image([source_image], vae)
    print(f"Latents shape: {encoded.shape}")
    print(f"Latents dtype: {encoded.dtype}")
    print(f"Latents device: {encoded.device}")

    # Decode back to images
    decoded = decode_latents(encoded, vae)
    print(f"Reconstructed {len(decoded)} images")

    # Save reconstructed image
    decoded[0].save("test_reconstruction.png")
    print("Saved reconstructed image to test_reconstruction.png")


def get_vae_stats(
    vae_path: str, device: Optional[torch.device] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Get latent mean and std from VAE config.

    For the Qwen VAE the config exposes 16-element ``latents_mean`` /
    ``latents_std`` lists (means near 0, stds near 2.37); zeros/ones of
    16 channels are used as a fallback when the config lacks them.

    Args:
        vae_path: Path to VAE model (hub id or local directory)
        device: Optional device to move the returned tensors to

    Returns:
        (mean, std) tuple of tensors with shape [1, C, 1, 1]
    """
    # Load config only to be fast (avoids downloading/instantiating weights).
    from transformers import PretrainedConfig

    try:
        config = PretrainedConfig.from_pretrained(vae_path)
    except Exception:
        # Fallback to loading full model if config load fails (e.g. local path issues)
        vae = AutoencoderKLQwenImage.from_pretrained(
            vae_path, torch_dtype=torch.bfloat16, local_files_only=True
        )
        config = vae.config
        del vae  # release model weights immediately; only the config is kept

    # getattr with default collapses the hasattr + None double-check.
    latents_mean = getattr(config, "latents_mean", None)
    if latents_mean is not None:
        mean = torch.tensor(latents_mean).view(1, -1, 1, 1)
    else:
        mean = torch.zeros(1, 16, 1, 1)

    latents_std = getattr(config, "latents_std", None)
    if latents_std is not None:
        std = torch.tensor(latents_std).view(1, -1, 1, 1)
    else:
        std = torch.ones(1, 16, 1, 1)

    # Explicit None check (was a truthiness test): torch.device is always
    # truthy, but `is not None` states the intent and matches the annotation.
    if device is not None:
        mean = mean.to(device)
        std = std.to(device)

    return mean, std