In [1]:
import torch
from PIL import Image
from diffusers import AutoencoderKL,AsymmetricAutoencoderKL
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
import os
from torchvision.transforms import ToTensor, Normalize, CenterCrop

# Input image and the directory where all results are written
IMG_PATH = "123456789.jpg"
OUT_DIR = "test"

# Device and precision used for the VAE weights and the input tensor
device = "cuda"
dtype = torch.float16

# VAEs to evaluate: display name (used in prints/file names) -> path passed to from_pretrained
VAES = {
    "test": "/workspace/simple_vae2x",
}

# Ensure the output directory exists before anything is saved into it
os.makedirs(OUT_DIR, exist_ok=True)

def load_image(path, multiple=8):
    """Load an image as RGB and turn it into a normalized batch tensor.

    The image is center-cropped so that height and width are multiples of
    ``multiple`` (presumably so the VAE's downsampling stages divide the
    size evenly — adjust if the VAE downsamples by a larger factor).

    Args:
        path: path to the image file.
        multiple: crop granularity in pixels (default 8, as before).

    Returns:
        (img, tensor): the cropped PIL image and a ``[1, 3, H, W]`` tensor
        scaled to [-1, 1], moved to the module-level ``device`` / ``dtype``.

    Raises:
        ValueError: if either side of the image is smaller than ``multiple``,
            which would otherwise produce an empty crop.
    """
    # Image.open is lazy and keeps the file open; the context manager closes
    # it once convert() has produced an independent, fully loaded image.
    with Image.open(path) as src:
        img = src.convert('RGB')

    w, h = img.size
    crop_h = h // multiple * multiple
    crop_w = w // multiple * multiple
    if crop_h == 0 or crop_w == 0:
        raise ValueError(
            f"image {path!r} of size {w}x{h} is smaller than the crop multiple {multiple}"
        )
    img = CenterCrop((crop_h, crop_w))(img)  # CenterCrop takes (height, width)

    tensor = ToTensor()(img).unsqueeze(0)  # [1, 3, H, W], values in [0, 1]
    tensor = Normalize(mean=[0.5] * 3, std=[0.5] * 3)(tensor)  # -> [-1, 1]
    return img, tensor.to(device, dtype=dtype)

# обратно в PIL
def tensor_to_img(t):
    """Convert the first image of a batch tensor in [-1, 1] to a PIL image.

    Values are mapped back to [0, 1] (undoing the mean=0.5 / std=0.5
    normalization applied by load_image) and clamped before conversion.
    """
    first = t[0]
    unit_range = first.mul(0.5).add(0.5).clamp(0, 1)
    return to_pil_image(unit_range)

def logvariance(latents, eps=1e-8):
    """Return the natural log of the variance over all elements of ``latents``.

    The variance is ``torch.Tensor.var``'s default (unbiased) estimate over
    the flattened tensor, so fewer than two elements yields NaN.

    It is computed in float32. This matters for the float16 latents used in
    this notebook: ``1e-8`` is below the smallest float16 subnormal and rounds
    to 0 when added to a half tensor. A zero variance would then give ``-inf``
    instead of ``log(eps)``, and the reduction itself would run in low
    precision.

    Args:
        latents: tensor of any shape and floating dtype.
        eps: small constant added before the log to avoid ``log(0)``.

    Returns:
        The log-variance as a Python float.
    """
    variance = latents.detach().float().var()
    return torch.log(variance + eps).item()

def plot_latent_distribution(latents, title, save_path, bins=100):
    """Save a histogram and a normal QQ-plot of all latent values.

    Args:
        latents: tensor of any shape; it is flattened before plotting.
        title: prefix used in both subplot titles.
        save_path: output file path passed to ``Figure.savefig``.
        bins: number of histogram bins (default 100, as before).
    """
    # Imported lazily: scipy is installed in a later cell of this notebook,
    # so a module-level import would fail in an environment without it.
    from scipy.stats import probplot

    # Upcast before leaving torch: float16 has ~3 significant decimal digits,
    # which makes the QQ-plot's sorted quantiles and least-squares fit coarse.
    lat = latents.detach().float().cpu().numpy().ravel()

    fig, (ax_hist, ax_qq) = plt.subplots(1, 2, figsize=(10, 4))
    try:
        ax_hist.hist(lat, bins=bins, density=True, alpha=0.7, color='steelblue')
        ax_hist.set_title(f"{title} histogram")
        ax_hist.set_xlabel("latent value")
        ax_hist.set_ylabel("density")

        # probplot accepts any object with .plot/.text, including an Axes.
        probplot(lat, dist="norm", plot=ax_qq)
        ax_qq.set_title(f"{title} QQ-plot")  # override probplot's default title

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls (one per VAE) do not accumulate open figures.
        plt.close(fig)

def _load_vae(name, repo):
    """Load one VAE in ``dtype`` and move it to ``device``.

    The entry named "test" is loaded as an AsymmetricAutoencoderKL from the
    ``vae`` subfolder; any other entry as a plain AutoencoderKL.
    """
    if name == "test":
        vae = AsymmetricAutoencoderKL.from_pretrained(repo, subfolder="vae", torch_dtype=dtype)
    else:
        vae = AutoencoderKL.from_pretrained(repo, torch_dtype=dtype)  # , subfolder="vae", variant="fp16"
    return vae.to(device)


def _latent_norm_params(cfg):
    """Read latent normalization constants from a VAE config as tensors.

    Returns:
        (scale, shift, mean, std). ``scale`` / ``shift`` default to 1 / 0 when
        missing or None. ``mean`` / ``std`` are None when the config does not
        define them; otherwise they are shaped [1, C, 1, 1], with C taken from
        their length instead of a hard-coded channel count.
    """
    scale = getattr(cfg, "scaling_factor", 1.0)
    shift = getattr(cfg, "shift_factor", 0.0)
    mean = getattr(cfg, "latents_mean", None)
    std = getattr(cfg, "latents_std", None)

    scale = torch.tensor(1.0 if scale is None else scale, device=device, dtype=dtype)
    shift = torch.tensor(0.0 if shift is None else shift, device=device, dtype=dtype)
    if mean is not None:
        mean = torch.tensor(mean, device=device, dtype=dtype).view(1, -1, 1, 1)
    if std is not None:
        std = torch.tensor(std, device=device, dtype=dtype).view(1, -1, 1, 1)
    return scale, shift, mean, std


def _normalize(latents, scale, shift, mean, std):
    """Map raw VAE latents to normalized space: ((z - mean) / std) * scale + shift.

    NOTE(review): diffusers' SD3/Flux code encodes as (z - shift) * scale.
    This notebook applies shift after scaling instead. The variance (and the
    printed log-variance) is the same either way, but the histogram's center
    differs. Confirm which convention the target model expects.
    """
    if mean is not None and std is not None:
        latents = (latents - mean) / std
    return latents * scale + shift


def _denormalize(latents, scale, shift, mean, std):
    """Exact inverse of ``_normalize``: recover raw latents for the decoder."""
    latents = (latents - shift) / scale
    if mean is not None and std is not None:
        latents = latents * std + mean
    return latents


# The input image does not depend on the VAE, so load and save it only once.
img, x = load_image(IMG_PATH)
img.save(os.path.join(OUT_DIR, "original.jpg"))

SEED = 0  # fixes posterior sampling so printed stats and plots are reproducible

for name, repo in VAES.items():
    vae = _load_vae(name, repo)
    scale, shift, mean, std = _latent_norm_params(vae.config)

    with torch.no_grad():
        # encode: draw a sample from the posterior with a seeded generator
        generator = torch.Generator(device=device).manual_seed(SEED)
        latents = vae.encode(x).latent_dist.sample(generator=generator).to(dtype)
        latents = _normalize(latents, scale, shift, mean, std)

        lv = logvariance(latents)
        print(f"{name} log-variance: {lv:.3f}")

        # plot
        plot_latent_distribution(latents, f"{name}_latents",
                                 os.path.join(OUT_DIR, f"dist_{name}.png"))

        # decode: undo the normalization before handing latents to the VAE
        rec = vae.decode(_denormalize(latents, scale, shift, mean, std)).sample

    tensor_to_img(rec).save(os.path.join(OUT_DIR, f"decoded_{name}.png"))

    # Free GPU memory before the next VAE is loaded.
    del vae, latents, rec
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("Готово")


The config attributes {'block_out_channels': [128, 256, 512, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.


test log-variance: 0.065
Готово


In [5]:
!pip install scipy

Collecting scipy
 Downloading scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (62 kB)
Downloading scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.7 MB)
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m35.7/35.7 MB[0m [31m58.9 MB/s[0m [33m0:00:00[0mm0:00:01[0m00:01[0m
[?25hInstalling collected packages: scipy
Successfully installed scipy-1.16.2
