{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8f3d2a61-title",
   "metadata": {},
   "source": [
    "# AutoencoderKL initialised from an AsymmetricAutoencoderKL checkpoint\n",
    "\n",
    "1. Build a freshly initialised `AutoencoderKL` (five `DownEncoderBlock2D` / `UpDecoderBlock2D` stages, 16 latent channels, no `quant_conv` / `post_quant_conv`) and save it to `vae_empty`.\n",
    "2. Copy into it every tensor of the `AsymmetricAutoencoderKL` stored in `vae16` whose key contains `encoder`, `quant_conv` or `decoder.up_blocks` **and** exists in the new model with the same shape.\n",
    "3. Save the result to `vae17`.\n",
    "\n",
    "Tensors that are not copied keep the fresh initialisation from step 1. Both models are moved to `device` and cast to `dtype` (CUDA and `float16` by default) before the transfer, so `vae17` is saved in that dtype."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c15deb04-94a0-4073-a174-adcd22af10b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e4c910-config-md",
   "metadata": {},
   "source": [
    "## 1. Empty target model\n",
    "\n",
    "All constructor arguments live in `config`. Keys that start with `_` are metadata and are not passed to the constructor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a9d6e2f-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture of the new (symmetric) AutoencoderKL.\n",
    "config = {\n",
    "    \"_class_name\": \"AutoencoderKL\",\n",
    "    \"act_fn\": \"silu\",\n",
    "    \"in_channels\": 3,\n",
    "    \"out_channels\": 3,\n",
    "    \"scaling_factor\": 1.0,\n",
    "    \"norm_num_groups\": 32,\n",
    "    \"block_out_channels\": [64, 128, 256, 512, 512],\n",
    "    \"down_block_types\": [\"DownEncoderBlock2D\"] * 5,\n",
    "    \"latent_channels\": 16,\n",
    "    \"up_block_types\": [\"UpDecoderBlock2D\"] * 5,\n",
    "    \"layers_per_block\": 3,\n",
    "    \"sample_size\": 1024,\n",
    "    \"use_quant_conv\": False,\n",
    "    \"use_post_quant_conv\": False,\n",
    "}\n",
    "\n",
    "# Create the empty (randomly initialised) model and save it so it can be reloaded below.\n",
    "vae = AutoencoderKL(**{k: v for k, v in config.items() if not k.startswith(\"_\")})\n",
    "vae.save_pretrained(\"vae_empty\")\n",
    "print(\"✅ Создана новая модель:\", type(vae))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2c5b8a7-transfer-md",
   "metadata": {},
   "source": [
    "## 2. Weight transfer\n",
    "\n",
    "`transfer_weights` prints:\n",
    "\n",
    "- how many source tensors were copied or skipped;\n",
    "- how many tensors of the new model received no source weights, grouped by key prefix;\n",
    "- the final model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59fcafb9-6d89-49b4-8362-b4891f591687",
   "metadata": {},
   "outputs": [],
   "source": [
    "def transfer_weights(old_path, new_path, save_path=\"asymmetric_vae\", device=\"cuda\", dtype=torch.float16):\n",
    "    \"\"\"Copy matching weights from an AsymmetricAutoencoderKL into an AutoencoderKL and save it.\n",
    "\n",
    "    Args:\n",
    "        old_path: source model, loaded with ``AsymmetricAutoencoderKL.from_pretrained``.\n",
    "        new_path: target model, loaded with ``AutoencoderKL.from_pretrained``.\n",
    "        save_path: directory the updated target model is saved to.\n",
    "        device: device both models are moved to.\n",
    "        dtype: dtype both models are cast to; the target is saved in this dtype.\n",
    "\n",
    "    A source tensor is copied only if its key contains \"encoder\", \"quant_conv\" or\n",
    "    \"decoder.up_blocks\" AND the target has a tensor with the same key and shape.\n",
    "    Every other target tensor keeps the value it was loaded with. Transfer\n",
    "    statistics and the final model are printed.\n",
    "    \"\"\"\n",
    "    old_vae = AsymmetricAutoencoderKL.from_pretrained(old_path).to(device, dtype=dtype)\n",
    "    new_vae = AutoencoderKL.from_pretrained(new_path).to(device, dtype=dtype)\n",
    "\n",
    "    old_sd = old_vae.state_dict()\n",
    "    new_sd = new_vae.state_dict()\n",
    "\n",
    "    transferred_keys = set()\n",
    "    # \"дублировано\" always stays 0 (there is no duplication step); it is kept so the\n",
    "    # printed statistics keep their layout.\n",
    "    transfer_stats = {\"перенесено\": 0, \"дублировано\": 0, \"пропущено\": 0}\n",
    "\n",
    "    print(\"\\n--- Перенос весов ---\")\n",
    "    for k, v in tqdm(old_sd.items()):\n",
    "        # \"quant_conv\" also matches \"post_quant_conv\".\n",
    "        # NOTE(review): decoder tensors outside decoder.up_blocks (e.g. decoder.mid_block,\n",
    "        # decoder.conv_in, decoder.conv_out) are never copied, even when their shapes match.\n",
    "        # Confirm this is intended.\n",
    "        eligible = (\"encoder\" in k) or (\"quant_conv\" in k) or (\"decoder.up_blocks\" in k)\n",
    "        if eligible and k in new_sd and new_sd[k].shape == v.shape:\n",
    "            new_sd[k] = v.clone()\n",
    "            transferred_keys.add(k)\n",
    "            transfer_stats[\"перенесено\"] += 1\n",
    "        else:\n",
    "            # Every source tensor that is not copied is counted as skipped.\n",
    "            transfer_stats[\"пропущено\"] += 1\n",
    "\n",
    "    # new_sd has exactly new_vae's keys, so a strict load cannot fail on missing/extra keys.\n",
    "    new_vae.load_state_dict(new_sd)\n",
    "    new_vae.save_pretrained(save_path)\n",
    "\n",
    "    print(\"\\n✅ Перенос завершён.\")\n",
    "    print(\"Статистика:\")\n",
    "    for k, v in transfer_stats.items():\n",
    "        print(f\" {k}: {v}\")\n",
    "\n",
    "    # Target tensors that received nothing from the source (left at their loaded values).\n",
    "    untouched = set(new_sd) - transferred_keys\n",
    "    print(f\" ключей новой модели без переноса: {len(untouched)}\")\n",
    "    for prefix in sorted({\".\".join(key.split(\".\")[:3]) for key in untouched}):\n",
    "        print(f\"   - {prefix}\")\n",
    "    print(new_vae)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41f7c3b-run",
   "metadata": {},
   "outputs": [],
   "source": [
    "transfer_weights(\"vae16\", \"vae_empty\", save_path=\"vae17\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}