{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0f3c9a52-6d1e-4b8a-9c71-2e5d8b4a1f60",
   "metadata": {},
   "source": [
    "# VAE latent normalization round-trip\n",
    "\n",
    "Encodes one image with each VAE listed in `VAES`, normalizes the latents, prints their statistics, then undoes the normalization and decodes, saving the reconstruction to `OUT_DIR` for visual comparison with `original.png`.\n",
    "\n",
    "- **FLUX.2** (`AutoencoderKLFlux2`): the 32-channel latents are patchified 2×2 into 128 channels and normalized with the VAE's BatchNorm running mean/variance.\n",
    "- **vae32ch2** (`AutoencoderKL`): per-channel `latents_mean` / `latents_std` from the config, followed by `shift_factor` / `scaling_factor`.\n",
    "\n",
    "**Last recorded run** (1280×1280 input, unseeded sampling): log-variance after normalization was −0.0767 for FLUX.2 (128 ch, std 0.9624) and 0.1192 for vae32ch2 (32 ch, std 1.0614); the target is ≈ 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c219c07b-8da2-4182-ace6-8c3cc63ae3b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional dependency: scipy is only used when PLOT_DISTRIBUTIONS = True.\n",
    "# %pip installs into this kernel's environment. The earlier `!pip install --user scipy`\n",
    "# apparently targeted a different environment: it found scipy in\n",
    "# /usr/local/lib/python3.12/dist-packages, yet `import scipy` failed in the kernel.\n",
    "%pip install -q scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a4e2b19-3c58-4d6f-8e20-b91f5c3d7a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from diffusers import AutoencoderKL, AutoencoderKLFlux2\n",
    "from PIL import Image\n",
    "from torchvision.transforms import CenterCrop, Normalize, ToTensor\n",
    "from torchvision.transforms.functional import to_pil_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d81f36-9b47-4c05-a3e8-5f60c1b29d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ── Settings ──────────────────────────────────────────────────────────────────\n",
    "IMG_PATH = \"1234.png\"\n",
    "OUT_DIR = \"vaetest\"\n",
    "DEVICE = \"cuda\"\n",
    "DTYPE = torch.float32\n",
    "SEED = 0                    # latent_dist.sample() is random; a fixed seed makes the stats reproducible\n",
    "PLOT_DISTRIBUTIONS = False  # True → also save a histogram + QQ-plot per VAE (needs scipy)\n",
    "\n",
    "# Image sides are cropped to a multiple of this: the VAE downsamples 8× (1280 px → 160 latents)\n",
    "# and FLUX.2 patchify then halves the latent H/W, so sides not divisible by 16 break it.\n",
    "SIZE_MULTIPLE = 16\n",
    "\n",
    "# display name → (loader kind, HF repo id or local path)\n",
    "VAES = {\n",
    "    \"FLUX.2\": (\"flux2\", \"AiArtLab/sdxs-1b\"),\n",
    "    \"vae32ch2\": (\"vae32ch\", \"vae32ch2\"),\n",
    "}\n",
    "\n",
    "os.makedirs(OUT_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b9f0d27-81c3-4e6a-b5d2-c73a0e94f158",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ── Patchify / Unpatchify ─────────────────────────────────────────────────────\n",
    "def _patchify_latents(latents):\n",
    "    \"\"\"Fold every 2×2 spatial patch into channels: (B, C, H, W) → (B, 4C, H/2, W/2).\"\"\"\n",
    "    B, C, H, W = latents.shape\n",
    "    if H % 2 or W % 2:\n",
    "        raise ValueError(f\"patchify needs even latent H and W, got {H}x{W}\")\n",
    "    latents = latents.reshape(B, C, H // 2, 2, W // 2, 2)\n",
    "    latents = latents.permute(0, 1, 3, 5, 2, 4)  # (B, C, 2, 2, H/2, W/2)\n",
    "    return latents.reshape(B, C * 4, H // 2, W // 2)\n",
    "\n",
    "\n",
    "def _unpatchify_latents(latents):\n",
    "    \"\"\"Exact inverse of `_patchify_latents`: (B, 4C, H, W) → (B, C, 2H, 2W).\"\"\"\n",
    "    B, C, H, W = latents.shape\n",
    "    if C % 4:\n",
    "        raise ValueError(f\"unpatchify needs a channel count divisible by 4, got {C}\")\n",
    "    latents = latents.reshape(B, C // 4, 2, 2, H, W)\n",
    "    latents = latents.permute(0, 1, 4, 2, 5, 3)  # (B, C/4, H, 2, W, 2)\n",
    "    return latents.reshape(B, C // 4, H * 2, W * 2)\n",
    "\n",
    "\n",
    "# ── Image I/O ─────────────────────────────────────────────────────────────────\n",
    "def load_image(path, multiple=16):\n",
    "    \"\"\"Load `path` as RGB and center-crop each side down to a multiple of `multiple`.\n",
    "\n",
    "    Returns (cropped PIL image, (1, 3, H, W) tensor scaled to [-1, 1] on DEVICE/DTYPE).\n",
    "    \"\"\"\n",
    "    img = Image.open(path).convert(\"RGB\")\n",
    "    w, h = img.size\n",
    "    img = CenterCrop((h // multiple * multiple, w // multiple * multiple))(img)\n",
    "    tensor = ToTensor()(img).unsqueeze(0)\n",
    "    tensor = Normalize(mean=[0.5] * 3, std=[0.5] * 3)(tensor)\n",
    "    return img, tensor.to(DEVICE, dtype=DTYPE)\n",
    "\n",
    "\n",
    "def tensor_to_img(t):\n",
    "    \"\"\"Convert the first image of a [-1, 1] batch back to a PIL image.\"\"\"\n",
    "    t = (t * 0.5 + 0.5).clamp(0, 1)\n",
    "    return to_pil_image(t[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c5a3e81-2f64-4d0b-8a17-e6b42d0c93f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ── Statistics ────────────────────────────────────────────────────────────────\n",
    "def logvariance(latents):\n",
    "    \"\"\"Log of the global variance over all elements (≈ 0 for unit-variance latents).\"\"\"\n",
    "    return torch.log(latents.var() + 1e-8).item()  # 1e-8 guards against log(0)\n",
    "\n",
    "\n",
    "def print_stats(name, latents):\n",
    "    \"\"\"Print log-variance, mean, std and shape of `latents` (`name` is currently unused).\"\"\"\n",
    "    lv = logvariance(latents)\n",
    "    print(f\" log-variance : {lv:.4f} (идеал ≈ 0.0)\")\n",
    "    print(f\" mean : {latents.mean():.4f}\")\n",
    "    print(f\" std : {latents.std():.4f}\")\n",
    "    print(f\" shape : {latents.shape}\")\n",
    "\n",
    "\n",
    "def plot_latent_distribution(latents, title, save_path):\n",
    "    \"\"\"Save a histogram and a normal QQ-plot of all latent values to `save_path`.\"\"\"\n",
    "    from scipy.stats import probplot  # local import: scipy is only needed for plotting\n",
    "\n",
    "    lat = latents.detach().cpu().float().numpy().ravel()\n",
    "\n",
    "    fig, (ax_hist, ax_qq) = plt.subplots(1, 2, figsize=(10, 4))\n",
    "\n",
    "    ax_hist.hist(lat, bins=100, density=True, alpha=0.7, color=\"steelblue\")\n",
    "    ax_hist.set(title=f\"{title} histogram\", xlabel=\"latent value\", ylabel=\"density\")\n",
    "\n",
    "    probplot(lat, dist=\"norm\", plot=ax_qq)\n",
    "    ax_qq.set_title(f\"{title} QQ-plot\")\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(save_path)\n",
    "    plt.close(fig)  # figures are only saved, never shown; free them\n",
    "    print(f\" график сохранён: {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d7e6c04-5a92-4f3b-b8c1-0e28f4a7d6b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ── Normalization constants / model loading ───────────────────────────────────\n",
    "def make_norm_tensors(cfg, latent_channels, device, dtype):\n",
    "    \"\"\"Build normalization tensors from a VAE config.\n",
    "\n",
    "    Returns (mean, std, shift, scale):\n",
    "      mean, std    -- (1, latent_channels, 1, 1) tensors from `latents_mean` /\n",
    "                      `latents_std`, or None when the config has no such value;\n",
    "      shift, scale -- 0-d tensors from `shift_factor` / `scaling_factor`;\n",
    "                      a missing, None or 0 shift becomes 0.0 and a missing or\n",
    "                      None scale becomes 1.0.\n",
    "    \"\"\"\n",
    "    mean = getattr(cfg, \"latents_mean\", None)\n",
    "    std = getattr(cfg, \"latents_std\", None)\n",
    "    shift = getattr(cfg, \"shift_factor\", None)\n",
    "    scale = getattr(cfg, \"scaling_factor\", None)\n",
    "\n",
    "    if mean is not None:\n",
    "        mean = torch.tensor(mean, device=device, dtype=dtype).view(1, latent_channels, 1, 1)\n",
    "    if std is not None:\n",
    "        std = torch.tensor(std, device=device, dtype=dtype).view(1, latent_channels, 1, 1)\n",
    "\n",
    "    shift = torch.tensor(shift if shift else 0.0, device=device, dtype=dtype)\n",
    "    # the key may exist with a null value; torch.tensor(None) would raise\n",
    "    scale = torch.tensor(scale if scale is not None else 1.0, device=device, dtype=dtype)\n",
    "    return mean, std, shift, scale\n",
    "\n",
    "\n",
    "def load_vae(kind, repo):\n",
    "    \"\"\"Load a VAE in eval mode on DEVICE/DTYPE.\n",
    "\n",
    "    kind == \"flux2\" → AutoencoderKLFlux2 from the repo's `vae` subfolder;\n",
    "    any other kind  → AutoencoderKL from the repo root.\n",
    "    \"\"\"\n",
    "    if kind == \"flux2\":\n",
    "        vae = AutoencoderKLFlux2.from_pretrained(repo, subfolder=\"vae\", torch_dtype=DTYPE)\n",
    "    else:\n",
    "        vae = AutoencoderKL.from_pretrained(repo, torch_dtype=DTYPE)\n",
    "    return vae.to(DEVICE).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b63f2a98-0e14-4c7d-9f35-8a1d5e2c0b47",
   "metadata": {},
   "outputs": [],
   "source": [
    "img, x = load_image(IMG_PATH, SIZE_MULTIPLE)\n",
    "img.save(os.path.join(OUT_DIR, \"original.png\"))\n",
    "print(f\"Картинка загружена: {x.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dccce86b-90a0-47c7-aaad-2ebb16d90756",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ── Main loop: encode → normalize → stats → denormalize → decode ─────────────\n",
    "for name, (kind, repo) in VAES.items():\n",
    "    print(f\"\\n{'='*55}\")\n",
    "    print(f\"VAE : {name}\")\n",
    "    print(f\"repo: {repo}\")\n",
    "\n",
    "    vae = load_vae(kind, repo)\n",
    "\n",
    "    latent_channels = vae.config.latent_channels\n",
    "    mean_t, std_t, shift_t, scale_t = make_norm_tensors(\n",
    "        vae.config, latent_channels, DEVICE, DTYPE\n",
    "    )\n",
    "\n",
    "    print(f\"latent_channels : {latent_channels}\")\n",
    "    print(f\"scaling_factor : {scale_t.item():.5f}\")\n",
    "    print(f\"shift_factor : {shift_t.item():.5f}\")\n",
    "    print(f\"latents_mean : {'да (' + str(latent_channels) + 'ch)' if mean_t is not None else 'нет'}\")\n",
    "    print(f\"latents_std : {'да (' + str(latent_channels) + 'ch)' if std_t is not None else 'нет'}\")\n",
    "\n",
    "    with torch.no_grad():\n",
    "\n",
    "        # ── ENCODE ────────────────────────────────────────────────────────────\n",
    "        torch.manual_seed(SEED)  # reseed so every VAE (and every re-run) gets the same draw\n",
    "        latents = vae.encode(x).latent_dist.sample().to(DTYPE)\n",
    "        print(f\"\\n[encode] raw latents: {latents.shape}\")\n",
    "\n",
    "        if kind == \"flux2\":\n",
    "            # 32ch → patchify → 128ch\n",
    "            latents_patched = _patchify_latents(latents)\n",
    "            print(f\"[flux2] patchify : {latents.shape} → {latents_patched.shape}\")\n",
    "\n",
    "            # BatchNorm normalization in the 128-channel patchified space\n",
    "            bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(DEVICE, DTYPE)\n",
    "            bn_std = torch.sqrt(\n",
    "                vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps\n",
    "            ).to(DEVICE, DTYPE)\n",
    "            latents_normed = (latents_patched - bn_mean) / bn_std\n",
    "            print(f\"[flux2] BN norm : mean={bn_mean.mean():.4f} std={bn_std.mean():.4f}\")\n",
    "\n",
    "            # statistics are measured in the normalized 128ch space\n",
    "            print(\"\\n[STATS] после BN нормализации (128ch):\")\n",
    "            print_stats(name, latents_normed)\n",
    "            stats_latents = latents_normed\n",
    "\n",
    "            # unpatchify → 32ch, the layout passed to vae.decode below\n",
    "            latents = _unpatchify_latents(latents_normed)\n",
    "\n",
    "        else:  # vae32ch2\n",
    "            # per-channel normalization from the config, then shift/scale\n",
    "            if mean_t is not None and std_t is not None:\n",
    "                latents = (latents - mean_t) / std_t\n",
    "            latents = (latents - shift_t) / scale_t\n",
    "\n",
    "            print(f\"\\n[STATS] после per-channel нормализации ({latent_channels}ch):\")\n",
    "            print_stats(name, latents)\n",
    "            stats_latents = latents\n",
    "\n",
    "        if PLOT_DISTRIBUTIONS:\n",
    "            plot_latent_distribution(\n",
    "                stats_latents,\n",
    "                f\"{name}_latents\",\n",
    "                os.path.join(OUT_DIR, f\"dist_{name}.png\"),\n",
    "            )\n",
    "\n",
    "        # ── DECODE ────────────────────────────────────────────────────────────\n",
    "        if kind == \"flux2\":\n",
    "            # patchify → denorm → unpatchify (inverse of the encode path)\n",
    "            latents_patched = _patchify_latents(latents)\n",
    "            latents_denormed = latents_patched * bn_std + bn_mean\n",
    "            latents = _unpatchify_latents(latents_denormed)\n",
    "            print(f\"\\n[flux2] BN denorm + unpatchify: {latents.shape}\")\n",
    "\n",
    "        else:  # vae32ch2: undo shift/scale first, then the per-channel stats\n",
    "            latents = latents * scale_t + shift_t\n",
    "            if mean_t is not None and std_t is not None:\n",
    "                latents = latents * std_t + mean_t\n",
    "            print(f\"\\n[vae32ch2] denorm: {latents.shape}\")\n",
    "\n",
    "        rec = vae.decode(latents).sample\n",
    "\n",
    "    out_path = os.path.join(OUT_DIR, f\"decoded_{name}.png\")\n",
    "    tensor_to_img(rec).save(out_path)\n",
    "    print(f\"Сохранено: {out_path}\")\n",
    "\n",
    "    # release this VAE before the next one is loaded\n",
    "    del vae\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "print(f\"\\n{'='*55}\")\n",
    "print(\"Готово\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}