# VAE round-trip check.
#
# For every VAE listed in VAES: encode one image, report the log-variance of
# the (normalized) latents, save a histogram + QQ-plot of the latent values,
# then decode the latents again and save the reconstruction.
#
# Requires matplotlib and scipy in the kernel's environment. Install missing
# packages *before* running this cell, in a dedicated top cell, e.g.
# `%pip install scipy` (the original notebook installed scipy only afterwards,
# so a fresh "Restart & Run All" failed on the scipy import).

import os

import matplotlib.pyplot as plt
import torch
from diffusers import AsymmetricAutoencoderKL, AutoencoderKL
from PIL import Image
from scipy.stats import probplot
from torchvision.transforms import CenterCrop, Normalize, ToTensor
from torchvision.transforms.functional import to_pil_image

# --- configuration -----------------------------------------------------------
IMG_PATH = "123456789.jpg"  # path to the input picture
OUT_DIR = "test"            # all artifacts are written here
SEED = 0                    # makes latent_dist.sample() reproducible

# Use the GPU with half precision when available; otherwise fall back to
# CPU/float32 so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# name -> checkpoint path / repo id of every VAE to evaluate
VAES = {
    "test": "/workspace/simple_vae2x",
}

os.makedirs(OUT_DIR, exist_ok=True)


def load_image(path, multiple=8):
    """Load an RGB image and prepare it as a model input.

    The image is center-cropped so that both sides are divisible by
    ``multiple`` and then mapped from [0, 1] to [-1, 1].

    Args:
        path: Path of the image file.
        multiple: Both sides of the crop are rounded down to a multiple of
            this value (default 8, as in the original code).

    Returns:
        ``(img, tensor)`` where ``img`` is the cropped PIL image and
        ``tensor`` has shape (1, 3, H, W) on ``device`` with ``dtype``.

    Raises:
        ValueError: If the image is smaller than ``multiple`` on either side
            (the crop would be empty).
    """
    # convert() returns a fully loaded copy, so the file handle can be closed.
    with Image.open(path) as handle:
        img = handle.convert("RGB")

    w, h = img.size
    crop_h, crop_w = h // multiple * multiple, w // multiple * multiple
    if crop_h == 0 or crop_w == 0:
        raise ValueError(
            f"image {path!r} is {w}x{h}, smaller than the crop multiple {multiple}"
        )

    img = CenterCrop((crop_h, crop_w))(img)
    tensor = ToTensor()(img).unsqueeze(0)                        # [0, 1]
    tensor = Normalize(mean=[0.5] * 3, std=[0.5] * 3)(tensor)    # [-1, 1]
    return img, tensor.to(device, dtype=dtype)


def tensor_to_img(t):
    """Convert a (1, 3, H, W) tensor in [-1, 1] back to a PIL image."""
    # Work in float32 on the CPU: avoids fp16 rounding in the rescale and
    # keeps to_pil_image independent of the compute device.
    t = (t.detach().float().cpu() * 0.5 + 0.5).clamp(0, 1)
    return to_pil_image(t[0])


def logvariance(latents):
    """Return the log of the variance taken over all latent elements.

    The statistic is computed in float32: in float16 the 1e-8 stabilizer
    rounds to zero and the variance reduction itself loses precision.
    """
    return torch.log(latents.float().var() + 1e-8).item()


def plot_latent_distribution(latents, title, save_path):
    """Save a histogram and a normal QQ-plot of the flattened latents.

    Args:
        latents: Tensor of latent values (any shape).
        title: Prefix used for both subplot titles.
        save_path: Destination file of the figure.
    """
    # float32 keeps numpy/scipy away from float16 arithmetic.
    lat = latents.detach().float().cpu().numpy().flatten()

    fig, (ax_hist, ax_qq) = plt.subplots(1, 2, figsize=(10, 4))

    # histogram
    ax_hist.hist(lat, bins=100, density=True, alpha=0.7, color='steelblue')
    ax_hist.set_title(f"{title} histogram")
    ax_hist.set_xlabel("latent value")
    ax_hist.set_ylabel("density")

    # QQ-plot against the normal distribution
    probplot(lat, dist="norm", plot=ax_qq)
    ax_qq.set_title(f"{title} QQ-plot")

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)  # do not accumulate open figures across VAEs


def _per_channel(values):
    """Turn a per-channel config list into a (1, C, 1, 1) tensor (None stays None).

    The channel count is inferred from the list length, so 4-channel and
    16-channel VAEs both work without a hard-coded constant.
    """
    if values is None:
        return None
    return torch.tensor(values, device=device, dtype=dtype).view(1, -1, 1, 1)


def _scalar_or(value, default):
    """Return ``value`` unless the config stores it as None, then ``default``."""
    return default if value is None else value


# The input image does not depend on the VAE: load and save it once.
img, x = load_image(IMG_PATH)
img.save(os.path.join(OUT_DIR, "original.jpg"))

for name, repo in VAES.items():
    if name == "test":
        # The local test checkpoint is an asymmetric VAE stored under "vae/".
        vae = AsymmetricAutoencoderKL.from_pretrained(
            repo, subfolder="vae", torch_dtype=dtype
        ).to(device)
    else:
        vae = AutoencoderKL.from_pretrained(repo, torch_dtype=dtype).to(device)

    cfg = vae.config
    scale = _scalar_or(getattr(cfg, "scaling_factor", 1.0), 1.0)
    shift = _scalar_or(getattr(cfg, "shift_factor", 0.0), 0.0)
    mean = _per_channel(getattr(cfg, "latents_mean", None))
    std = _per_channel(getattr(cfg, "latents_std", None))
    use_stats = mean is not None and std is not None

    with torch.no_grad():
        # encode (seeded so the sampled latents are reproducible per VAE)
        torch.manual_seed(SEED)
        latents = vae.encode(x).latent_dist.sample().to(dtype)
        if use_stats:
            latents = (latents - mean) / std
        # NOTE(review): diffusers pipelines normalize as
        # (latents - shift_factor) * scaling_factor; this notebook applies
        # scale first and then adds the shift. The decode path below inverts
        # it exactly, and the variance is unaffected by the shift, but the
        # plotted mean differs for VAEs with a non-zero shift_factor - confirm
        # which convention is intended.
        latents = latents * scale + shift

        lv = logvariance(latents)
        print(f"{name} log-variance: {lv:.3f}")

        # distribution plot
        plot_latent_distribution(
            latents, f"{name}_latents", os.path.join(OUT_DIR, f"dist_{name}.png")
        )

        # decode: undo the normalization in reverse order
        latents = (latents - shift) / scale
        if use_stats:
            latents = latents * std + mean
        rec = vae.decode(latents).sample

    tensor_to_img(rec).save(os.path.join(OUT_DIR, f"decoded_{name}.png"))

print("Готово")