diff --git "a/create.ipynb" "b/create.ipynb" deleted file mode 100644--- "a/create.ipynb" +++ /dev/null @@ -1,2905 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 7, - "id": "407171be-ab46-442b-a0bd-83ca75173eba", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "AsymmetricAutoencoderKL(\n", - " (encoder): Encoder(\n", - " (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (down_blocks): ModuleList(\n", - " (0): DownEncoderBlock2D(\n", - " (resnets): ModuleList(\n", - " (0-1): 2 x ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " (downsamplers): ModuleList(\n", - " (0): Downsample2D(\n", - " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n", - " )\n", - " )\n", - " )\n", - " (1): DownEncoderBlock2D(\n", - " (resnets): ModuleList(\n", - " (0): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (1): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " 
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " (downsamplers): ModuleList(\n", - " (0): Downsample2D(\n", - " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n", - " )\n", - " )\n", - " )\n", - " (2): DownEncoderBlock2D(\n", - " (resnets): ModuleList(\n", - " (0): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (1): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " (downsamplers): ModuleList(\n", - " (0): Downsample2D(\n", - " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n", - " )\n", - " )\n", - " )\n", - " (3): DownEncoderBlock2D(\n", - " (resnets): ModuleList(\n", - " (0-1): 2 x ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (mid_block): UNetMidBlock2D(\n", - " (attentions): ModuleList(\n", 
- " (0): Attention(\n", - " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (to_q): Linear(in_features=512, out_features=512, bias=True)\n", - " (to_k): Linear(in_features=512, out_features=512, bias=True)\n", - " (to_v): Linear(in_features=512, out_features=512, bias=True)\n", - " (to_out): ModuleList(\n", - " (0): Linear(in_features=512, out_features=512, bias=True)\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " (resnets): ModuleList(\n", - " (0-1): 2 x ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " )\n", - " (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (conv_act): SiLU()\n", - " (conv_out): Conv2d(512, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " )\n", - " (decoder): MaskConditionDecoder(\n", - " (conv_in): Conv2d(16, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (up_blocks): ModuleList(\n", - " (0-1): 2 x UpDecoderBlock2D(\n", - " (resnets): ModuleList(\n", - " (0-2): 3 x ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 768, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 768, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " (upsamplers): ModuleList(\n", - " (0): Upsample2D(\n", - " (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " )\n", - " )\n", - " )\n", - " (2): UpDecoderBlock2D(\n", - " (resnets): 
ModuleList(\n", - " (0): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 768, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (1-2): 2 x ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " (upsamplers): ModuleList(\n", - " (0): Upsample2D(\n", - " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " )\n", - " )\n", - " )\n", - " (3): UpDecoderBlock2D(\n", - " (resnets): ModuleList(\n", - " (0): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (1-2): 2 x ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " (upsamplers): ModuleList(\n", - " (0): Upsample2D(\n", - " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " )\n", - " )\n", - " )\n", - " (4): UpDecoderBlock2D(\n", - " (resnets): ModuleList(\n", - " (0): ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (1-2): 2 x ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (mid_block): UNetMidBlock2D(\n", - " (attentions): ModuleList(\n", - " (0): Attention(\n", - " (group_norm): GroupNorm(32, 768, eps=1e-06, affine=True)\n", - " (to_q): Linear(in_features=768, out_features=768, bias=True)\n", - " (to_k): Linear(in_features=768, out_features=768, bias=True)\n", - " (to_v): Linear(in_features=768, out_features=768, bias=True)\n", - " (to_out): ModuleList(\n", - " (0): Linear(in_features=768, out_features=768, bias=True)\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " (resnets): ModuleList(\n", - " (0-1): 2 x ResnetBlock2D(\n", - " (norm1): GroupNorm(32, 768, eps=1e-06, affine=True)\n", - " (conv1): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (norm2): 
GroupNorm(32, 768, eps=1e-06, affine=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (conv2): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (nonlinearity): SiLU()\n", - " )\n", - " )\n", - " )\n", - " (condition_encoder): MaskConditionEncoder(\n", - " (layers): Sequential(\n", - " (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (2): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (3): Conv2d(512, 768, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (4): Conv2d(768, 768, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " )\n", - " )\n", - " (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n", - " (conv_act): SiLU()\n", - " (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " )\n", - " (quant_conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (post_quant_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))\n", - ")\n" - ] - } - ], - "source": [ - "from diffusers.models import AsymmetricAutoencoderKL\n", - "import torch\n", - "\n", - "config = {\n", - " \"_class_name\": \"AsymmetricAutoencoderKL\",\n", - " \"act_fn\": \"silu\",\n", - " \"down_block_out_channels\": [128, 256, 512, 512],\n", - " \"down_block_types\": [\n", - " \"DownEncoderBlock2D\",\n", - " \"DownEncoderBlock2D\",\n", - " \"DownEncoderBlock2D\",\n", - " \"DownEncoderBlock2D\",\n", - " ],\n", - " \"in_channels\": 3,\n", - " \"latent_channels\": 16,\n", - " #\"latents_mean\": [0.2539, 0.1431, 0.1484, -0.3048, -0.0985, -0.162, 0.1403, 0.2034, -0.1419, 0.2646, 0.0655, 0.0061, 0.1555, 0.0506, 0.0129, -0.1948],\n", - " #\"latents_std\": [0.8123, 0.7376, 0.7354, 1.1827, 0.8387, 0.8735, 0.8705, 0.8142, 0.8076, 0.7409, 0.7655, 0.8731, 0.8087, 0.7058, 0.8087, 0.7615],\n", - " #\"layers_per_block\": 2,\n", - " \"norm_num_groups\": 32,\n", - " 
\"out_channels\": 3,\n", - " \"sample_size\": 1024,\n", - " \"scaling_factor\": 1,\n", - " \"shift_factor\": 0,\n", - " \"up_block_out_channels\": [128, 256, 512, 768, 768],\n", - " \"up_block_types\": [\n", - " \"UpDecoderBlock2D\",\n", - " \"UpDecoderBlock2D\",\n", - " \"UpDecoderBlock2D\",\n", - " \"UpDecoderBlock2D\",\n", - " \"UpDecoderBlock2D\",\n", - " ],\n", - "}\n", - "\n", - "# Преобразуем списки mean и std в тензоры\n", - "#latents_mean = torch.tensor(config[\"latents_mean\"])\n", - "#latents_std = torch.tensor(config[\"latents_std\"])\n", - "\n", - "# Создаем модель\n", - "vae = AsymmetricAutoencoderKL(\n", - " act_fn=config[\"act_fn\"],\n", - " down_block_out_channels=config[\"down_block_out_channels\"],\n", - " down_block_types=config[\"down_block_types\"],\n", - " in_channels=config[\"in_channels\"],\n", - " latent_channels=config[\"latent_channels\"],\n", - " norm_num_groups=config[\"norm_num_groups\"],\n", - " out_channels=config[\"out_channels\"],\n", - " sample_size=config[\"sample_size\"],\n", - " scaling_factor=config[\"scaling_factor\"],\n", - " up_block_out_channels=config[\"up_block_out_channels\"],\n", - " up_block_types=config[\"up_block_types\"],\n", - " layers_per_down_block = 2,\n", - " layers_per_up_block = 2\n", - ")\n", - "\n", - "# Устанавливаем mean и std для латентов\n", - "#vae.latents_mean = latents_mean\n", - "#vae.latents_std = latents_std\n", - "\n", - "vae.save_pretrained(\"simple_vae\")\n", - "print(vae)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "290a6758-5aa8-47a4-ba2f-ece8abf6df88", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The config attributes {'block_out_channels': [128, 256, 512, 768, 768], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. 
Please verify your config.json configuration file.\n", - "Перенос весов: 100%|██████████| 228/228 [00:00<00:00, 192647.32it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "✗ Несовпадение размеров: decoder.conv_in.weight (torch.Size([512, 16, 3, 3])) -> decoder.conv_in.weight (torch.Size([768, 16, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.conv_in.bias (torch.Size([512])) -> decoder.conv_in.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm1.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm1.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.0.conv1.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.conv1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm2.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm2.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.0.conv2.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.conv2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.norm1.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.norm1.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: 
decoder.up_blocks.0.resnets.1.norm1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.norm1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.1.conv1.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.conv1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.conv1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.norm2.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.norm2.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.norm2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.norm2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.1.conv2.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.conv2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.conv2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.norm1.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.norm1.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.norm1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.norm1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.2.conv1.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.conv1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.conv1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.norm2.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.norm2.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: 
decoder.up_blocks.0.resnets.2.norm2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.norm2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.2.conv2.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.conv2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.conv2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.upsamplers.0.conv.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.upsamplers.0.conv.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.0.upsamplers.0.conv.bias (torch.Size([512])) -> decoder.up_blocks.0.upsamplers.0.conv.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.norm1.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.norm1.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.norm1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.norm1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.0.conv1.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.conv1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.conv1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.norm2.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.norm2.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.norm2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.norm2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.0.conv2.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение 
размеров: decoder.up_blocks.1.resnets.0.conv2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.conv2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.norm1.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.norm1.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.norm1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.norm1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.1.conv1.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.conv1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.conv1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.norm2.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.norm2.weight (torch.Size([768]))\n", - "✗ Несовпадение размер��в: decoder.up_blocks.1.resnets.1.norm2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.norm2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.1.conv2.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.conv2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.conv2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.norm1.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.norm1.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.norm1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.norm1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.2.conv1.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: 
decoder.up_blocks.1.resnets.2.conv1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.conv1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.norm2.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.norm2.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.norm2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.norm2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.2.conv2.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.conv2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.conv2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.upsamplers.0.conv.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.upsamplers.0.conv.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.1.upsamplers.0.conv.bias (torch.Size([512])) -> decoder.up_blocks.1.upsamplers.0.conv.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.norm1.weight (torch.Size([512])) -> decoder.up_blocks.2.resnets.0.norm1.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.norm1.bias (torch.Size([512])) -> decoder.up_blocks.2.resnets.0.norm1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv1.weight (torch.Size([256, 512, 3, 3])) -> decoder.up_blocks.2.resnets.0.conv1.weight (torch.Size([512, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.conv1.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.norm2.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.norm2.weight (torch.Size([512]))\n", - "✗ Несовпадение размеров: 
decoder.up_blocks.2.resnets.0.norm2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.norm2.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv2.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.conv2.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv_shortcut.weight (torch.Size([256, 512, 1, 1])) -> decoder.up_blocks.2.resnets.0.conv_shortcut.weight (torch.Size([512, 768, 1, 1]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv_shortcut.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.conv_shortcut.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.norm1.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.norm1.weight (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.norm1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.norm1.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.conv1.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.conv1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.conv1.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.norm2.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.norm2.weight (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.norm2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.norm2.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.conv2.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3]))\n", 
- "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.conv2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.conv2.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.norm1.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.norm1.weight (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.norm1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.norm1.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.conv1.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.2.conv1.weight (torch.Size([512, 512, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.conv1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.conv1.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.norm2.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.norm2.weight (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.norm2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.norm2.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.conv2.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.2.conv2.weight (torch.Size([512, 512, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.conv2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.conv2.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.upsamplers.0.conv.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.upsamplers.0.conv.weight (torch.Size([512, 512, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.2.upsamplers.0.conv.bias (torch.Size([256])) -> decoder.up_blocks.2.upsamplers.0.conv.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.norm1.weight (torch.Size([256])) -> decoder.up_blocks.3.resnets.0.norm1.weight (torch.Size([512]))\n", - "✗ Несовпадение 
размеров: decoder.up_blocks.3.resnets.0.norm1.bias (torch.Size([256])) -> decoder.up_blocks.3.resnets.0.norm1.bias (torch.Size([512]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv1.weight (torch.Size([128, 256, 3, 3])) -> decoder.up_blocks.3.resnets.0.conv1.weight (torch.Size([256, 512, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.conv1.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.norm2.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.norm2.weight (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.norm2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.norm2.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv2.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.0.conv2.weight (torch.Size([256, 256, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.conv2.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv_shortcut.weight (torch.Size([128, 256, 1, 1])) -> decoder.up_blocks.3.resnets.0.conv_shortcut.weight (torch.Size([256, 512, 1, 1]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv_shortcut.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.conv_shortcut.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.norm1.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.norm1.weight (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.norm1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.norm1.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.conv1.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.1.conv1.weight (torch.Size([256, 256, 
3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.conv1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.conv1.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.norm2.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.norm2.weight (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.norm2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.norm2.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.conv2.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.1.conv2.weight (torch.Size([256, 256, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.conv2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.conv2.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.norm1.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.norm1.weight (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.norm1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.norm1.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.conv1.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.2.conv1.weight (torch.Size([256, 256, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.conv1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.conv1.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.norm2.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.norm2.weight (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.norm2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.norm2.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.conv2.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.2.conv2.weight (torch.Size([256, 256, 3, 3]))\n", - "✗ 
Несовпадение размеров: decoder.up_blocks.3.resnets.2.conv2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.conv2.bias (torch.Size([256]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.0.norm1.weight (torch.Size([512])) -> decoder.mid_block.resnets.0.norm1.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.0.norm1.bias (torch.Size([512])) -> decoder.mid_block.resnets.0.norm1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.0.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.mid_block.resnets.0.conv1.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.0.conv1.bias (torch.Size([512])) -> decoder.mid_block.resnets.0.conv1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.0.norm2.weight (torch.Size([512])) -> decoder.mid_block.resnets.0.norm2.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.0.norm2.bias (torch.Size([512])) -> decoder.mid_block.resnets.0.norm2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.mid_block.resnets.0.conv2.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.0.conv2.bias (torch.Size([512])) -> decoder.mid_block.resnets.0.conv2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.1.norm1.weight (torch.Size([512])) -> decoder.mid_block.resnets.1.norm1.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.1.norm1.bias (torch.Size([512])) -> decoder.mid_block.resnets.1.norm1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.mid_block.resnets.1.conv1.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.1.conv1.bias 
(torch.Size([512])) -> decoder.mid_block.resnets.1.conv1.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.1.norm2.weight (torch.Size([512])) -> decoder.mid_block.resnets.1.norm2.weight (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.1.norm2.bias (torch.Size([512])) -> decoder.mid_block.resnets.1.norm2.bias (torch.Size([768]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.mid_block.resnets.1.conv2.weight (torch.Size([768, 768, 3, 3]))\n", - "✗ Несовпадение размеров: decoder.mid_block.resnets.1.conv2.bias (torch.Size([512])) -> decoder.mid_block.resnets.1.conv2.bias (torch.Size([768]))\n", - "Статистика переноса: {'перенесено': 104, 'несовпадение_размеров': 124, 'пропущено': 0}\n", - "Неперенесенные ключи в новой модели:\n", - "decoder.condition_encoder.layers.0.bias\n", - "decoder.condition_encoder.layers.0.weight\n", - "decoder.condition_encoder.layers.1.bias\n", - "decoder.condition_encoder.layers.1.weight\n", - "decoder.condition_encoder.layers.2.bias\n", - "decoder.condition_encoder.layers.2.weight\n", - "decoder.condition_encoder.layers.3.bias\n", - "decoder.condition_encoder.layers.3.weight\n", - "decoder.condition_encoder.layers.4.bias\n", - "decoder.condition_encoder.layers.4.weight\n", - "decoder.conv_in.bias\n", - "decoder.conv_in.weight\n", - "decoder.mid_block.attentions.0.group_norm.bias\n", - "decoder.mid_block.attentions.0.group_norm.weight\n", - "decoder.mid_block.attentions.0.to_k.bias\n", - "decoder.mid_block.attentions.0.to_k.weight\n", - "decoder.mid_block.attentions.0.to_out.0.bias\n", - "decoder.mid_block.attentions.0.to_out.0.weight\n", - "decoder.mid_block.attentions.0.to_q.bias\n", - "decoder.mid_block.attentions.0.to_q.weight\n", - "decoder.mid_block.attentions.0.to_v.bias\n", - "decoder.mid_block.attentions.0.to_v.weight\n", - "decoder.mid_block.resnets.0.conv1.bias\n", - 
"decoder.mid_block.resnets.0.conv1.weight\n", - "decoder.mid_block.resnets.0.conv2.bias\n", - "decoder.mid_block.resnets.0.conv2.weight\n", - "decoder.mid_block.resnets.0.norm1.bias\n", - "decoder.mid_block.resnets.0.norm1.weight\n", - "decoder.mid_block.resnets.0.norm2.bias\n", - "decoder.mid_block.resnets.0.norm2.weight\n", - "decoder.mid_block.resnets.1.conv1.bias\n", - "decoder.mid_block.resnets.1.conv1.weight\n", - "decoder.mid_block.resnets.1.conv2.bias\n", - "decoder.mid_block.resnets.1.conv2.weight\n", - "decoder.mid_block.resnets.1.norm1.bias\n", - "decoder.mid_block.resnets.1.norm1.weight\n", - "decoder.mid_block.resnets.1.norm2.bias\n", - "decoder.mid_block.resnets.1.norm2.weight\n", - "decoder.up_blocks.0.resnets.0.conv1.bias\n", - "decoder.up_blocks.0.resnets.0.conv1.weight\n", - "decoder.up_blocks.0.resnets.0.conv2.bias\n", - "decoder.up_blocks.0.resnets.0.conv2.weight\n", - "decoder.up_blocks.0.resnets.0.norm1.bias\n", - "decoder.up_blocks.0.resnets.0.norm1.weight\n", - "decoder.up_blocks.0.resnets.0.norm2.bias\n", - "decoder.up_blocks.0.resnets.0.norm2.weight\n", - "decoder.up_blocks.0.resnets.1.conv1.bias\n", - "decoder.up_blocks.0.resnets.1.conv1.weight\n", - "decoder.up_blocks.0.resnets.1.conv2.bias\n", - "decoder.up_blocks.0.resnets.1.conv2.weight\n", - "decoder.up_blocks.0.resnets.1.norm1.bias\n", - "decoder.up_blocks.0.resnets.1.norm1.weight\n", - "decoder.up_blocks.0.resnets.1.norm2.bias\n", - "decoder.up_blocks.0.resnets.1.norm2.weight\n", - "decoder.up_blocks.0.resnets.2.conv1.bias\n", - "decoder.up_blocks.0.resnets.2.conv1.weight\n", - "decoder.up_blocks.0.resnets.2.conv2.bias\n", - "decoder.up_blocks.0.resnets.2.conv2.weight\n", - "decoder.up_blocks.0.resnets.2.norm1.bias\n", - "decoder.up_blocks.0.resnets.2.norm1.weight\n", - "decoder.up_blocks.0.resnets.2.norm2.bias\n", - "decoder.up_blocks.0.resnets.2.norm2.weight\n", - "decoder.up_blocks.0.upsamplers.0.conv.bias\n", - "decoder.up_blocks.0.upsamplers.0.conv.weight\n", - 
"decoder.up_blocks.1.resnets.0.conv1.bias\n", - "decoder.up_blocks.1.resnets.0.conv1.weight\n", - "decoder.up_blocks.1.resnets.0.conv2.bias\n", - "decoder.up_blocks.1.resnets.0.conv2.weight\n", - "decoder.up_blocks.1.resnets.0.norm1.bias\n", - "decoder.up_blocks.1.resnets.0.norm1.weight\n", - "decoder.up_blocks.1.resnets.0.norm2.bias\n", - "decoder.up_blocks.1.resnets.0.norm2.weight\n", - "decoder.up_blocks.1.resnets.1.conv1.bias\n", - "decoder.up_blocks.1.resnets.1.conv1.weight\n", - "decoder.up_blocks.1.resnets.1.conv2.bias\n", - "decoder.up_blocks.1.resnets.1.conv2.weight\n", - "decoder.up_blocks.1.resnets.1.norm1.bias\n", - "decoder.up_blocks.1.resnets.1.norm1.weight\n", - "decoder.up_blocks.1.resnets.1.norm2.bias\n", - "decoder.up_blocks.1.resnets.1.norm2.weight\n", - "decoder.up_blocks.1.resnets.2.conv1.bias\n", - "decoder.up_blocks.1.resnets.2.conv1.weight\n", - "decoder.up_blocks.1.resnets.2.conv2.bias\n", - "decoder.up_blocks.1.resnets.2.conv2.weight\n", - "decoder.up_blocks.1.resnets.2.norm1.bias\n", - "decoder.up_blocks.1.resnets.2.norm1.weight\n", - "decoder.up_blocks.1.resnets.2.norm2.bias\n", - "decoder.up_blocks.1.resnets.2.norm2.weight\n", - "decoder.up_blocks.1.upsamplers.0.conv.bias\n", - "decoder.up_blocks.1.upsamplers.0.conv.weight\n", - "decoder.up_blocks.2.resnets.0.conv1.bias\n", - "decoder.up_blocks.2.resnets.0.conv1.weight\n", - "decoder.up_blocks.2.resnets.0.conv2.bias\n", - "decoder.up_blocks.2.resnets.0.conv2.weight\n", - "decoder.up_blocks.2.resnets.0.conv_shortcut.bias\n", - "decoder.up_blocks.2.resnets.0.conv_shortcut.weight\n", - "decoder.up_blocks.2.resnets.0.norm1.bias\n", - "decoder.up_blocks.2.resnets.0.norm1.weight\n", - "decoder.up_blocks.2.resnets.0.norm2.bias\n", - "decoder.up_blocks.2.resnets.0.norm2.weight\n", - "decoder.up_blocks.2.resnets.1.conv1.bias\n", - "decoder.up_blocks.2.resnets.1.conv1.weight\n", - "decoder.up_blocks.2.resnets.1.conv2.bias\n", - "decoder.up_blocks.2.resnets.1.conv2.weight\n", - 
"decoder.up_blocks.2.resnets.1.norm1.bias\n", - "decoder.up_blocks.2.resnets.1.norm1.weight\n", - "decoder.up_blocks.2.resnets.1.norm2.bias\n", - "decoder.up_blocks.2.resnets.1.norm2.weight\n", - "decoder.up_blocks.2.resnets.2.conv1.bias\n", - "decoder.up_blocks.2.resnets.2.conv1.weight\n", - "decoder.up_blocks.2.resnets.2.conv2.bias\n", - "decoder.up_blocks.2.resnets.2.conv2.weight\n", - "decoder.up_blocks.2.resnets.2.norm1.bias\n", - "decoder.up_blocks.2.resnets.2.norm1.weight\n", - "decoder.up_blocks.2.resnets.2.norm2.bias\n", - "decoder.up_blocks.2.resnets.2.norm2.weight\n", - "decoder.up_blocks.2.upsamplers.0.conv.bias\n", - "decoder.up_blocks.2.upsamplers.0.conv.weight\n", - "decoder.up_blocks.3.resnets.0.conv1.bias\n", - "decoder.up_blocks.3.resnets.0.conv1.weight\n", - "decoder.up_blocks.3.resnets.0.conv2.bias\n", - "decoder.up_blocks.3.resnets.0.conv2.weight\n", - "decoder.up_blocks.3.resnets.0.conv_shortcut.bias\n", - "decoder.up_blocks.3.resnets.0.conv_shortcut.weight\n", - "decoder.up_blocks.3.resnets.0.norm1.bias\n", - "decoder.up_blocks.3.resnets.0.norm1.weight\n", - "decoder.up_blocks.3.resnets.0.norm2.bias\n", - "decoder.up_blocks.3.resnets.0.norm2.weight\n", - "decoder.up_blocks.3.resnets.1.conv1.bias\n", - "decoder.up_blocks.3.resnets.1.conv1.weight\n", - "decoder.up_blocks.3.resnets.1.conv2.bias\n", - "decoder.up_blocks.3.resnets.1.conv2.weight\n", - "decoder.up_blocks.3.resnets.1.norm1.bias\n", - "decoder.up_blocks.3.resnets.1.norm1.weight\n", - "decoder.up_blocks.3.resnets.1.norm2.bias\n", - "decoder.up_blocks.3.resnets.1.norm2.weight\n", - "decoder.up_blocks.3.resnets.2.conv1.bias\n", - "decoder.up_blocks.3.resnets.2.conv1.weight\n", - "decoder.up_blocks.3.resnets.2.conv2.bias\n", - "decoder.up_blocks.3.resnets.2.conv2.weight\n", - "decoder.up_blocks.3.resnets.2.norm1.bias\n", - "decoder.up_blocks.3.resnets.2.norm1.weight\n", - "decoder.up_blocks.3.resnets.2.norm2.bias\n", - "decoder.up_blocks.3.resnets.2.norm2.weight\n", - 
"decoder.up_blocks.3.upsamplers.0.conv.bias\n", - "decoder.up_blocks.3.upsamplers.0.conv.weight\n", - "decoder.up_blocks.4.resnets.0.conv1.bias\n", - "decoder.up_blocks.4.resnets.0.conv1.weight\n", - "decoder.up_blocks.4.resnets.0.conv2.bias\n", - "decoder.up_blocks.4.resnets.0.conv2.weight\n", - "decoder.up_blocks.4.resnets.0.conv_shortcut.bias\n", - "decoder.up_blocks.4.resnets.0.conv_shortcut.weight\n", - "decoder.up_blocks.4.resnets.0.norm1.bias\n", - "decoder.up_blocks.4.resnets.0.norm1.weight\n", - "decoder.up_blocks.4.resnets.0.norm2.bias\n", - "decoder.up_blocks.4.resnets.0.norm2.weight\n", - "decoder.up_blocks.4.resnets.1.conv1.bias\n", - "decoder.up_blocks.4.resnets.1.conv1.weight\n", - "decoder.up_blocks.4.resnets.1.conv2.bias\n", - "decoder.up_blocks.4.resnets.1.conv2.weight\n", - "decoder.up_blocks.4.resnets.1.norm1.bias\n", - "decoder.up_blocks.4.resnets.1.norm1.weight\n", - "decoder.up_blocks.4.resnets.1.norm2.bias\n", - "decoder.up_blocks.4.resnets.1.norm2.weight\n", - "decoder.up_blocks.4.resnets.2.conv1.bias\n", - "decoder.up_blocks.4.resnets.2.conv1.weight\n", - "decoder.up_blocks.4.resnets.2.conv2.bias\n", - "decoder.up_blocks.4.resnets.2.conv2.weight\n", - "decoder.up_blocks.4.resnets.2.norm1.bias\n", - "decoder.up_blocks.4.resnets.2.norm1.weight\n", - "decoder.up_blocks.4.resnets.2.norm2.bias\n", - "decoder.up_blocks.4.resnets.2.norm2.weight\n", - "encoder.mid_block.attentions.0.group_norm.bias\n", - "encoder.mid_block.attentions.0.group_norm.weight\n", - "encoder.mid_block.attentions.0.to_k.bias\n", - "encoder.mid_block.attentions.0.to_k.weight\n", - "encoder.mid_block.attentions.0.to_out.0.bias\n", - "encoder.mid_block.attentions.0.to_out.0.weight\n", - "encoder.mid_block.attentions.0.to_q.bias\n", - "encoder.mid_block.attentions.0.to_q.weight\n", - "encoder.mid_block.attentions.0.to_v.bias\n", - "encoder.mid_block.attentions.0.to_v.weight\n" - ] - } - ], - "source": [ - "import torch\n", - "from diffusers import 
import torch
from tqdm import tqdm


def log(message):
    """Print a progress/diagnostic message (kept as a hook for a future logging backend)."""
    print(message)


def initialize_mid_block_weights(state_dict, device, dtype, num_channels=512):
    """Identity-initialize the encoder mid-block attention GroupNorm.

    The new AsymmetricAutoencoderKL has attention layers that the old
    AutoencoderKL lacks, so their GroupNorm affine parameters get an identity
    init (weight=1, bias=0) instead of staying randomly initialized.

    Args:
        state_dict: mutable mapping of parameter name -> tensor; updated in place.
        device: device for the newly created tensors.
        dtype: dtype for the newly created tensors.
        num_channels: GroupNorm channel count (default 512 — matches the
            encoder's last block, as seen in the transfer logs above).

    Returns:
        The same state_dict, for call chaining.
    """
    state_dict['encoder.mid_block.attentions.0.group_norm.weight'] = torch.ones(
        num_channels, device=device, dtype=dtype)
    state_dict['encoder.mid_block.attentions.0.group_norm.bias'] = torch.zeros(
        num_channels, device=device, dtype=dtype)
    return state_dict


def main():
    """Transfer all name-and-shape-matching weights from the old AutoencoderKL
    into the new AsymmetricAutoencoderKL, report mismatches, and save the result.
    """
    # Imported lazily so the pure helpers above remain usable without diffusers.
    from diffusers import AsymmetricAutoencoderKL, AutoencoderKL

    checkpoint_path_old = "AiArtLab/sdxs"
    checkpoint_path_new = "simple_vae"
    device = "cuda"
    dtype = torch.float16

    # Source (old) and target (new) models.
    old_unet = AutoencoderKL.from_pretrained(
        checkpoint_path_old, subfolder="vae", variant="fp16"
    ).to(device, dtype=dtype)
    new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)

    old_state_dict = old_unet.state_dict()
    new_state_dict = new_unet.state_dict()

    transferred_state_dict = {}
    transfer_stats = {
        "перенесено": 0,
        "несовпадение_размеров": 0,
        "пропущено": 0
    }
    transferred_keys = set()

    # Copy every old parameter whose name AND shape both match the new model.
    for old_key in tqdm(old_state_dict.keys(), desc="Перенос весов"):
        new_key = old_key

        if new_key in new_state_dict:
            if old_state_dict[old_key].shape == new_state_dict[new_key].shape:
                transferred_state_dict[new_key] = old_state_dict[old_key].clone()
                transferred_keys.add(new_key)
                transfer_stats["перенесено"] += 1
            else:
                log(f"✗ Несовпадение размеров: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})")
                transfer_stats["несовпадение_размеров"] += 1
        else:
            log(f"? Ключ не найден в новой модели: {old_key} -> {old_state_dict[old_key].shape}")
            transfer_stats["пропущено"] += 1

    # Apply the transferred weights, then identity-init the new mid-block norms.
    new_state_dict.update(transferred_state_dict)
    new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)

    new_unet.load_state_dict(new_state_dict)
    new_unet.save_pretrained("vae")

    # Everything that kept its random init (missing or shape-mismatched keys).
    non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)

    print("Статистика переноса:", transfer_stats)
    print("Неперенесенные ключи в новой модели:")
    for key in non_transferred_keys:
        print(key)

if __name__ == "__main__":
    main()
import torch
from tqdm import tqdm


def initialize_mid_block_weights(state_dict, device, dtype, num_channels=512):
    """Identity-init the encoder mid-block attention GroupNorm (weight=1, bias=0).

    Redefined in this cell so it runs on a fresh kernel: previously the name
    leaked in from the weight-transfer cell above and this cell raised
    NameError under Restart & Run All.
    """
    state_dict['encoder.mid_block.attentions.0.group_norm.weight'] = torch.ones(
        num_channels, device=device, dtype=dtype)
    state_dict['encoder.mid_block.attentions.0.group_norm.bias'] = torch.zeros(
        num_channels, device=device, dtype=dtype)
    return state_dict


def normalize_weights3(state_dict, latents_mean, latents_std):
    """LEGACY first attempt — superseded by normalize_weights() below and no
    longer called by main(); kept only for reference. Uses `i % len(...)`
    channel wrap-around, which also (incorrectly) rescales the log-variance
    half of the 32-channel encoder output.
    """
    device = next(iter(state_dict.values())).device
    dtype = next(iter(state_dict.values())).dtype

    # Convert to tensors on the state dict's device/dtype.
    latents_mean = torch.tensor(latents_mean, device=device, dtype=dtype)
    latents_std = torch.tensor(latents_std, device=device, dtype=dtype)

    # Encoder side: fold (x - mean) / std into quant_conv output channels.
    if 'quant_conv.weight' in state_dict:
        weight = state_dict['quant_conv.weight']  # [32, 32, 1, 1]
        for i in range(weight.size(0)):
            weight[i] = weight[i] / latents_std[i % len(latents_std)]

    if 'quant_conv.bias' in state_dict:
        bias = state_dict['quant_conv.bias']  # [32]
        for i in range(bias.size(0)):
            bias[i] = (bias[i] - latents_mean[i % len(latents_mean)]) / latents_std[i % len(latents_std)]

    # Decoder side: fold x * std + mean into post_quant_conv input channels.
    if 'post_quant_conv.weight' in state_dict:
        weight = state_dict['post_quant_conv.weight']  # [16, 16, 1, 1]
        for i in range(weight.size(1)):
            weight[:, i] = weight[:, i] * latents_std[i]

    if 'post_quant_conv.bias' in state_dict:
        bias = state_dict['post_quant_conv.bias']  # [16]
        for i in range(bias.size(0)):
            bias[i] = bias[i] * latents_std[i] + latents_mean[i]

    return state_dict


def normalize_weights(state_dict, latents_mean, latents_std):
    """Fold per-channel latent normalization into the VAE's 1x1 convs, in place.

    Encoder: the post-hoc normalization y = (x - mean) / std applied after the
    1x1 quant_conv folds exactly into the conv itself:
        W_i -> W_i / std_i,   b_i -> (b_i - mean_i) / std_i
    Decoder: the denormalization z = z_norm * std + mean applied before the 1x1
    post_quant_conv folds into its input channels:
        W[:, i] -> W[:, i] * std_i,   b -> b + W @ mean
    (the mean shift must pass through the conv weight, computed BEFORE scaling).

    Only the first len(latents_std) channels are touched; the remaining encoder
    output channels are left as-is.
    NOTE(review): if the 32 quant_conv channels are [16 mean, 16 logvar] as in
    diffusers' DiagonalGaussianDistribution, the logvar half would additionally
    need a -2*log(std) bias shift for a fully exact fold — TODO confirm layout.

    Returns the same state_dict, for call chaining.
    """
    device = next(iter(state_dict.values())).device
    dtype = next(iter(state_dict.values())).dtype

    latents_mean = torch.tensor(latents_mean, device=device, dtype=dtype)
    latents_std = torch.tensor(latents_std, device=device, dtype=dtype)

    # Encoder side: quant_conv (32 -> 32 channels).
    if 'quant_conv.weight' in state_dict:
        weight = state_dict['quant_conv.weight']  # [32, 32, 1, 1]
        for i in range(weight.size(0)):
            if i < len(latents_std):
                weight[i] = weight[i] / latents_std[i]

    if 'quant_conv.bias' in state_dict:
        bias = state_dict['quant_conv.bias']  # [32]
        for i in range(bias.size(0)):
            if i < len(latents_mean):
                # BUGFIX: the fold is (b - mean) / std; the previous version
                # wrote -mean / std and silently discarded the original bias.
                bias[i] = (bias[i] - latents_mean[i]) / latents_std[i]

    # Decoder side: post_quant_conv (16 -> 16 channels).
    pq_weight = state_dict.get('post_quant_conv.weight')  # [16, 16, 1, 1]
    pq_bias = state_dict.get('post_quant_conv.bias')      # [16]
    if pq_weight is not None:
        n = min(pq_weight.size(1), len(latents_std))
        if pq_bias is not None:
            # BUGFIX: the mean shift goes through the conv: b += W[:, :n] @ mean,
            # using the not-yet-scaled weight (previously mean_i was added
            # per-channel, which is only correct for an identity weight).
            pq_bias += (pq_weight[:, :n, 0, 0] * latents_mean[:n]).sum(dim=1)
        for i in range(n):
            pq_weight[:, i] = pq_weight[:, i] * latents_std[i]
    elif pq_bias is not None:
        # No weight to propagate through; fall back to a per-channel shift.
        for i in range(min(pq_bias.size(0), len(latents_mean))):
            pq_bias[i] = pq_bias[i] + latents_mean[i]

    return state_dict


def main():
    """Bake the dataset latent statistics into the saved 'vae' model and write
    the normalized copy to 'vaenorm'."""
    # Imported lazily so the pure helpers above remain usable without diffusers.
    from diffusers import AsymmetricAutoencoderKL

    model_path = "vae"
    device = "cuda"
    dtype = torch.float16

    # Per-channel latent statistics measured over the training set.
    latents_mean = [0.2539, 0.1431, 0.1484, -0.3048, -0.0985, -0.162, 0.1403, 0.2034, -0.1419, 0.2646, 0.0655, 0.0061, 0.1555, 0.0506, 0.0129, -0.1948]

    latents_std = [0.8123, 0.7376, 0.7354, 1.1827, 0.8387, 0.8735, 0.8705, 0.8142, 0.8076, 0.7409, 0.7655, 0.8731, 0.8087, 0.7058, 0.8087, 0.7615]

    model = AsymmetricAutoencoderKL.from_pretrained(model_path).to(device, dtype=dtype)

    state_dict = model.state_dict()

    # Show the affected tensors before touching them.
    print("\nWeights before normalization:")
    for key in ['quant_conv.weight', 'quant_conv.bias', 'post_quant_conv.weight', 'post_quant_conv.bias']:
        if key in state_dict:
            print(f"{key}: {state_dict[key].shape}")

    normalized_state_dict = normalize_weights(state_dict, latents_mean, latents_std)
    normalized_state_dict = initialize_mid_block_weights(normalized_state_dict, device, dtype)

    model.load_state_dict(normalized_state_dict)

    model.save_pretrained("vaenorm")

if __name__ == "__main__":
    main()
Please verify your config.json configuration file.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Processing decoder.up_blocks.3.resnets.0.norm1.weight\n", - "Old shape: torch.Size([256])\n", - "Target shape: torch.Size([512])\n", - "Interpolating tensor of shape torch.Size([256]) to target shape torch.Size([512])\n", - "Extending 1D tensor from 256 to 512\n", - "\n", - "Processing decoder.up_blocks.3.resnets.0.norm1.bias\n", - "Old shape: torch.Size([256])\n", - "Target shape: torch.Size([512])\n", - "Interpolating tensor of shape torch.Size([256]) to target shape torch.Size([512])\n", - "Extending 1D tensor from 256 to 512\n", - "\n", - "Processing decoder.up_blocks.3.resnets.0.conv1.weight\n", - "Old shape: torch.Size([128, 256, 3, 3])\n", - "Target shape: torch.Size([256, 512, 3, 3])\n", - "Interpolating tensor of shape torch.Size([128, 256, 3, 3]) to target shape torch.Size([256, 512, 3, 3])\n", - "Copying existing weights: min_out=128, min_in=256\n", - "Extending output channels from 128 to 256\n", - "Extending input channels from 256 to 512\n", - "\n", - "Processing decoder.up_blocks.3.resnets.0.conv1.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.0.norm2.weight\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.0.norm2.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.0.conv2.weight\n", 
- "Old shape: torch.Size([128, 128, 3, 3])\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", - "Copying existing weights: min_out=128, min_in=128\n", - "Extending output channels from 128 to 256\n", - "Extending input channels from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.0.conv2.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.0.conv_shortcut.weight\n", - "Old shape: torch.Size([128, 256, 1, 1])\n", - "Target shape: torch.Size([256, 512, 1, 1])\n", - "Interpolating tensor of shape torch.Size([128, 256, 1, 1]) to target shape torch.Size([256, 512, 1, 1])\n", - "Copying existing weights: min_out=128, min_in=256\n", - "Extending output channels from 128 to 256\n", - "Extending input channels from 256 to 512\n", - "\n", - "Processing decoder.up_blocks.3.resnets.0.conv_shortcut.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.1.norm1.weight\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.1.norm1.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.1.conv1.weight\n", - "Old shape: 
torch.Size([128, 128, 3, 3])\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", - "Copying existing weights: min_out=128, min_in=128\n", - "Extending output channels from 128 to 256\n", - "Extending input channels from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.1.conv1.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.1.norm2.weight\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.1.norm2.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.1.conv2.weight\n", - "Old shape: torch.Size([128, 128, 3, 3])\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", - "Copying existing weights: min_out=128, min_in=128\n", - "Extending output channels from 128 to 256\n", - "Extending input channels from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.1.conv2.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.2.norm1.weight\n", - "Old shape: torch.Size([128])\n", - "Target shape: 
torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.2.norm1.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.2.conv1.weight\n", - "Old shape: torch.Size([128, 128, 3, 3])\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", - "Copying existing weights: min_out=128, min_in=128\n", - "Extending output channels from 128 to 256\n", - "Extending input channels from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.2.conv1.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.2.norm2.weight\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.2.norm2.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.2.conv2.weight\n", - "Old shape: torch.Size([128, 128, 3, 3])\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", - "Copying existing weights: 
min_out=128, min_in=128\n", - "Extending output channels from 128 to 256\n", - "Extending input channels from 128 to 256\n", - "\n", - "Processing decoder.up_blocks.3.resnets.2.conv2.bias\n", - "Old shape: torch.Size([128])\n", - "Target shape: torch.Size([256])\n", - "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", - "Extending 1D tensor from 128 to 256\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Перенос весов: 100%|██████████| 286/286 [00:00<00:00, 163407.02it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.weight -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.weight -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.bias -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.bias -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_q.weight -> torch.Size([512, 512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_q.weight -> torch.Size([512, 512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_q.bias -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_q.bias -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_k.weight -> torch.Size([512, 512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_k.weight -> torch.Size([512, 512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_k.bias -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_k.bias -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_v.weight -> torch.Size([512, 512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_v.weight -> torch.Size([512, 512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_v.bias -> torch.Size([512])\n", - "? 
Ключ пропущен: encoder.mid_block.attentions.0.to_v.bias -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_out.0.weight -> torch.Size([512, 512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_out.0.weight -> torch.Size([512, 512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_out.0.bias -> torch.Size([512])\n", - "? Ключ пропущен: encoder.mid_block.attentions.0.to_out.0.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.conv_in.weight -> torch.Size([768, 16, 3, 3])\n", - "? Ключ пропущен: decoder.conv_in.weight -> torch.Size([768, 16, 3, 3])\n", - "? Ключ пропущен: decoder.conv_in.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.conv_in.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? 
Ключ пропущен: decoder.up_blocks.0.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm1.bias -> torch.Size([768])\n", - "? 
Ключ пропущен: decoder.up_blocks.0.resnets.2.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.upsamplers.0.conv.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.upsamplers.0.conv.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.0.upsamplers.0.conv.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.0.upsamplers.0.conv.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? 
Ключ пропущен: decoder.up_blocks.1.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm2.bias -> torch.Size([768])\n", - "? 
Ключ пропущен: decoder.up_blocks.1.resnets.1.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.upsamplers.0.conv.weight -> torch.Size([768, 768, 3, 3])\n", - "? 
Ключ пропущен: decoder.up_blocks.1.upsamplers.0.conv.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.1.upsamplers.0.conv.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.1.upsamplers.0.conv.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv1.weight -> torch.Size([512, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv1.weight -> torch.Size([512, 768, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv1.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv1.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm2.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm2.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv2.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv2.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv_shortcut.weight -> torch.Size([512, 768, 1, 1])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv_shortcut.weight -> torch.Size([512, 768, 1, 1])\n", - "? 
Ключ пропущен: decoder.up_blocks.2.resnets.0.conv_shortcut.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv_shortcut.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm1.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm1.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm1.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm1.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv1.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv1.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv1.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv1.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm2.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm2.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv2.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv2.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm1.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm1.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm1.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm1.bias -> torch.Size([512])\n", - "? 
Ключ пропущен: decoder.up_blocks.2.resnets.2.conv1.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv1.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv1.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv1.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm2.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm2.weight -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv2.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv2.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.upsamplers.0.conv.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.upsamplers.0.conv.weight -> torch.Size([512, 512, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.2.upsamplers.0.conv.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.2.upsamplers.0.conv.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.up_blocks.3.upsamplers.0.conv.weight -> torch.Size([256, 256, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.3.upsamplers.0.conv.weight -> torch.Size([256, 256, 3, 3])\n", - "? Ключ пропущен: decoder.up_blocks.3.upsamplers.0.conv.bias -> torch.Size([256])\n", - "? Ключ пропущен: decoder.up_blocks.3.upsamplers.0.conv.bias -> torch.Size([256])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.group_norm.weight -> torch.Size([768])\n", - "? 
Ключ пропущен: decoder.mid_block.attentions.0.group_norm.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.group_norm.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.group_norm.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_q.weight -> torch.Size([768, 768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_q.weight -> torch.Size([768, 768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_q.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_q.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_k.weight -> torch.Size([768, 768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_k.weight -> torch.Size([768, 768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_k.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_k.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_v.weight -> torch.Size([768, 768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_v.weight -> torch.Size([768, 768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_v.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_v.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_out.0.weight -> torch.Size([768, 768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_out.0.weight -> torch.Size([768, 768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_out.0.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.attentions.0.to_out.0.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.norm1.bias -> torch.Size([768])\n", - "? 
Ключ пропущен: decoder.mid_block.resnets.0.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.0.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.norm1.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.norm1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.conv1.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.norm2.weight -> torch.Size([768])\n", - "? 
Ключ пропущен: decoder.mid_block.resnets.1.norm2.weight -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.norm2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.mid_block.resnets.1.conv2.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.0.weight -> torch.Size([128, 3, 3, 3])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.0.weight -> torch.Size([128, 3, 3, 3])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.0.bias -> torch.Size([128])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.0.bias -> torch.Size([128])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.1.weight -> torch.Size([256, 128, 3, 3])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.1.weight -> torch.Size([256, 128, 3, 3])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.1.bias -> torch.Size([256])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.1.bias -> torch.Size([256])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.2.weight -> torch.Size([512, 256, 4, 4])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.2.weight -> torch.Size([512, 256, 4, 4])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.2.bias -> torch.Size([512])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.3.weight -> torch.Size([768, 512, 4, 4])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.3.weight -> torch.Size([768, 512, 4, 4])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.3.bias -> torch.Size([768])\n", - "? 
Ключ пропущен: decoder.condition_encoder.layers.3.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.4.weight -> torch.Size([768, 768, 4, 4])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.4.weight -> torch.Size([768, 768, 4, 4])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.4.bias -> torch.Size([768])\n", - "? Ключ пропущен: decoder.condition_encoder.layers.4.bias -> torch.Size([768])\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "ename": "RuntimeError", - "evalue": "Error(s) in loading state_dict for AsymmetricAutoencoderKL:\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv2.bias: copying a param with shape 
torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm1.weight: copying a param 
with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[6], line 164\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28mprint\u001b[39m(key)\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;18m__name__\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m__main__\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 164\u001b[0m 
\u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[6], line 153\u001b[0m, in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;66;03m# Инициализируем веса для нового mid блока\u001b[39;00m\n\u001b[1;32m 150\u001b[0m new_state_dict \u001b[38;5;241m=\u001b[39m initialize_mid_block_weights(new_state_dict, device, dtype)\n\u001b[0;32m--> 153\u001b[0m \u001b[43mnew_unet\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnew_state_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m new_unet\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvae\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 156\u001b[0m \u001b[38;5;66;03m# Выводим статистику\u001b[39;00m\n", - "File \u001b[0;32m~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:2581\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2573\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[1;32m 2574\u001b[0m \u001b[38;5;241m0\u001b[39m,\n\u001b[1;32m 2575\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. 
\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2576\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)\n\u001b[1;32m 2577\u001b[0m ),\n\u001b[1;32m 2578\u001b[0m )\n\u001b[1;32m 2580\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 2581\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 2582\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2583\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)\n\u001b[1;32m 2584\u001b[0m )\n\u001b[1;32m 2585\u001b[0m )\n\u001b[1;32m 2586\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", - "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for AsymmetricAutoencoderKL:\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).\n\tsize 
mismatch for decoder.up_blocks.4.resnets.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is 
torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is 
import torch
import torch.nn.init as init

# Optional third-party dependencies: only the model-loading entry point needs
# them; the tensor helpers below work with torch alone, so degrade gracefully.
try:
    from diffusers import AsymmetricAutoencoderKL, AutoencoderKL
except ImportError:  # diffusers only required by main()
    AsymmetricAutoencoderKL = AutoencoderKL = None
try:
    from tqdm import tqdm
except ImportError:  # fall back to a plain iterator when tqdm is absent
    def tqdm(iterable, **kwargs):
        return iterable


def log(message):
    """Print a progress/diagnostic message (thin wrapper kept for call-site clarity)."""
    print(message)


def initialize_mid_block_weights(state_dict, device, dtype):
    """Seed affine GroupNorm parameters for the encoder mid-block attention.

    Writes weight=1 / bias=0 (the standard GroupNorm affine initialization,
    size 512) for ``encoder.mid_block.attentions.0.group_norm``.

    NOTE(review): the entries are overwritten unconditionally, even if the
    checkpoint already provided them — confirm that is intended.

    Args:
        state_dict: mutable mapping of parameter name -> tensor; mutated in place.
        device: device for the new tensors (e.g. ``"cuda"``).
        dtype: dtype for the new tensors (e.g. ``torch.float16``).

    Returns:
        The same ``state_dict``, for chaining.
    """
    state_dict['encoder.mid_block.attentions.0.group_norm.weight'] = torch.ones(512, device=device, dtype=dtype)
    state_dict['encoder.mid_block.attentions.0.group_norm.bias'] = torch.zeros(512, device=device, dtype=dtype)
    return state_dict


def interpolate_tensor(tensor, target_shape):
    """Resize ``tensor`` to ``target_shape`` by copying the overlapping region
    and edge-replicating along every dimension that grew.

    Despite the historical name, this does not interpolate values: the region
    shared by both shapes is copied verbatim, a shrinking dimension is simply
    truncated, and a growing dimension is padded by repeating its last copied
    slice.

    Generalized to tensors of any rank. The original implementation handled
    only 4-D conv kernels and true 1-D vectors; any 2-D tensor (e.g. attention
    ``to_q.weight``) fell into the 1-D branch and raised a shape-mismatch
    error because only dimension 0 was considered.

    Args:
        tensor: source tensor.
        target_shape: desired shape; must have the same rank as ``tensor``.

    Returns:
        A new tensor of shape ``target_shape`` on the same device/dtype.

    Raises:
        ValueError: if the ranks of ``tensor`` and ``target_shape`` differ.
    """
    print(f"Interpolating tensor of shape {tensor.shape} to target shape {target_shape}")

    if len(tensor.shape) != len(target_shape):
        raise ValueError(
            f"Rank mismatch: tensor has shape {tuple(tensor.shape)}, "
            f"target shape is {tuple(target_shape)}"
        )

    ndim = len(target_shape)
    result = torch.zeros(target_shape, device=tensor.device, dtype=tensor.dtype)

    # Copy the region where the old and new shapes overlap.
    overlap = tuple(slice(0, min(s, t)) for s, t in zip(tensor.shape, target_shape))
    result[overlap] = tensor[overlap]

    # For every dimension that grew, replicate the last copied slice outward.
    # Processing dimensions in order reproduces the original 4-D fill order
    # (output channels first, then input channels).
    for dim in range(ndim):
        copied = min(tensor.shape[dim], target_shape[dim])
        grow = target_shape[dim] - copied
        if grow > 0:
            print(f"Extending dim {dim} from {copied} to {target_shape[dim]}")
            src = [slice(None)] * ndim
            src[dim] = slice(copied - 1, copied)
            dst = [slice(None)] * ndim
            dst[dim] = slice(copied, target_shape[dim])
            reps = [1] * ndim
            reps[dim] = grow
            result[tuple(dst)] = result[tuple(src)].repeat(*reps)

    return result


def should_interpolate(key):
    """Return True for parameter keys whose tensors may be resized on transfer.

    NOTE(review): the trailing ``'bias'`` / ``'weight'`` patterns make this
    match virtually every parameter key, so the layer-specific prefixes are
    effectively redundant — confirm the intended filter before relying on it.
    """
    return any(x in key for x in [
        'conv1', 'conv2', 'conv_shortcut',  # convolution layers
        'norm1', 'norm2', 'group_norm',     # normalization layers
        'bias', 'weight',                   # generic parameters
    ])
main():\n", - " checkpoint_path_old = \"AiArtLab/sdxs\"\n", - " checkpoint_path_new = \"simple_vae\"\n", - " device = \"cuda\"\n", - " dtype = torch.float16\n", - "\n", - " # Загрузка моделей\n", - " old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old,subfolder=\"vae\",variant=\"fp16\").to(device, dtype=dtype)\n", - " new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n", - "\n", - " old_state_dict = old_unet.state_dict()\n", - " new_state_dict = new_unet.state_dict()\n", - "\n", - " transferred_state_dict = {}\n", - " transfer_stats = {\n", - " \"перенесено\": 0,\n", - " \"несовпадение_размеров\": 0,\n", - " \"пропущено\": 0\n", - " }\n", - "\n", - " transferred_keys = set()\n", - "\n", - " # Сначала найдем все ключи блока 3 и их интерполированные значения\n", - " block3_interpolated = {}\n", - " for key in old_state_dict:\n", - " if 'decoder.up_blocks.3.' in key and should_interpolate(key):\n", - " old_tensor = old_state_dict[key]\n", - " new_key = key\n", - " if new_key in new_state_dict and old_tensor.shape != new_state_dict[new_key].shape:\n", - " print(f\"\\nProcessing {key}\")\n", - " print(f\"Old shape: {old_tensor.shape}\")\n", - " print(f\"Target shape: {new_state_dict[new_key].shape}\")\n", - " interpolated = interpolate_tensor(old_tensor, new_state_dict[new_key].shape)\n", - " block3_interpolated[key] = interpolated\n", - "\n", - " # Обрабатываем каждый ключ новой модели\n", - " for new_key in tqdm(new_state_dict.keys(), desc=\"Перенос весов\"):\n", - " # Случай 1: Прямое соответствие ключей и размеров\n", - " if new_key in old_state_dict and old_state_dict[new_key].shape == new_state_dict[new_key].shape:\n", - " transferred_state_dict[new_key] = old_state_dict[new_key].clone()\n", - " transferred_keys.add(new_key)\n", - " transfer_stats[\"перенесено\"] += 1\n", - " continue\n", - "\n", - " # Случай 2: Блоки 4 и 5 (копируем интерполированные веса блока 3)\n", - " if ('decoder.up_blocks.4.' 
in new_key or 'decoder.up_blocks.5.' in new_key) and should_interpolate(new_key):\n", - " source_key = new_key.replace('decoder.up_blocks.4.', 'decoder.up_blocks.3.')\n", - " source_key = source_key.replace('decoder.up_blocks.5.', 'decoder.up_blocks.3.')\n", - " \n", - " if source_key in block3_interpolated:\n", - " transferred_state_dict[new_key] = block3_interpolated[source_key].clone()\n", - " transferred_keys.add(new_key)\n", - " transfer_stats[\"перенесено\"] += 1\n", - " continue\n", - "\n", - " # Случай 3: Несовпадение размеров в блоке 3\n", - " if 'decoder.up_blocks.3.' in new_key and new_key in block3_interpolated:\n", - " transferred_state_dict[new_key] = block3_interpolated[new_key].clone()\n", - " transferred_keys.add(new_key)\n", - " transfer_stats[\"перенесено\"] += 1\n", - " continue\n", - "\n", - " # Если ключ не обработан - помечаем как пропущенный\n", - " transfer_stats[\"пропущено\"] += 1\n", - " log(f\"? Ключ пропущен: {new_key} -> {new_state_dict[new_key].shape}\")\n", - "\n", - " # Если ключ не обработан - помечаем как пропущенный\n", - " transfer_stats[\"пропущено\"] += 1\n", - " log(f\"? 
Ключ пропущен: {new_key} -> {new_state_dict[new_key].shape}\")\n", - "\n", - " # Обновляем состояние новой модели\n", - " new_state_dict.update(transferred_state_dict)\n", - "\n", - " # Инициализируем веса для нового mid блока\n", - " new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)\n", - "\n", - " \n", - " new_unet.load_state_dict(new_state_dict)\n", - " new_unet.save_pretrained(\"vae\")\n", - "\n", - " # Выводим статистику\n", - " print(\"\\nСтатистика переноса:\", transfer_stats)\n", - " print(\"\\nНеперенесенные ключи:\")\n", - " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n", - " for key in non_transferred_keys:\n", - " print(key)\n", - "\n", - "if __name__ == \"__main__\":\n", - " main()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "b3018b9a-82cf-4435-8afd-4448ff5e83a9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading models...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The config attributes {'block_out_channels': [128, 256, 512, 768, 768], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. 
Please verify your config.json configuration file.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Processing weights...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 286/286 [00:00<00:00, 106458.20it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Processing key: encoder.conv_in.weight\n", - "Target shape: torch.Size([128, 3, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.conv_in.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.0.norm1.weight\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.0.norm1.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.0.conv1.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.0.conv1.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.0.norm2.weight\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.0.norm2.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.0.conv2.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.0.conv2.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.1.norm1.weight\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.1.norm1.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - 
"Processing key: encoder.down_blocks.0.resnets.1.conv1.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.1.conv1.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.1.norm2.weight\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.1.norm2.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.1.conv2.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.resnets.1.conv2.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.downsamplers.0.conv.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.0.downsamplers.0.conv.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.0.norm1.weight\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.0.norm1.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.0.conv1.weight\n", - "Target shape: torch.Size([256, 128, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.0.conv1.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.0.norm2.weight\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.0.norm2.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: 
encoder.down_blocks.1.resnets.0.conv2.weight\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.0.conv2.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.0.conv_shortcut.weight\n", - "Target shape: torch.Size([256, 128, 1, 1])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.0.conv_shortcut.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.1.norm1.weight\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.1.norm1.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.1.conv1.weight\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.1.conv1.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.1.norm2.weight\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.1.norm2.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.1.conv2.weight\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.resnets.1.conv2.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.downsamplers.0.conv.weight\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.1.downsamplers.0.conv.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: 
encoder.down_blocks.2.resnets.0.norm1.weight\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.0.norm1.bias\n", - "Target shape: torch.Size([256])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.0.conv1.weight\n", - "Target shape: torch.Size([512, 256, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.0.conv1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.0.norm2.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.0.norm2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.0.conv2.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.0.conv2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.0.conv_shortcut.weight\n", - "Target shape: torch.Size([512, 256, 1, 1])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.0.conv_shortcut.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.1.norm1.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.1.norm1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.1.conv1.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.1.conv1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: 
encoder.down_blocks.2.resnets.1.norm2.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.1.norm2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.1.conv2.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.resnets.1.conv2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.downsamplers.0.conv.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.2.downsamplers.0.conv.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.0.norm1.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.0.norm1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.0.conv1.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.0.conv1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.0.norm2.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.0.norm2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.0.conv2.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.0.conv2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: 
encoder.down_blocks.3.resnets.1.norm1.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.1.norm1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.1.conv1.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.1.conv1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.1.norm2.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.1.norm2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.1.conv2.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.down_blocks.3.resnets.1.conv2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.attentions.0.group_norm.weight\n", - "Target shape: torch.Size([512])\n", - "Key not found in source model\n", - "\n", - "Processing key: encoder.mid_block.attentions.0.group_norm.bias\n", - "Target shape: torch.Size([512])\n", - "Key not found in source model\n", - "\n", - "Processing key: encoder.mid_block.attentions.0.to_q.weight\n", - "Target shape: torch.Size([512, 512])\n", - "Key not found in source model\n", - "\n", - "Processing key: encoder.mid_block.attentions.0.to_q.bias\n", - "Target shape: torch.Size([512])\n", - "Key not found in source model\n", - "\n", - "Processing key: encoder.mid_block.attentions.0.to_k.weight\n", - "Target shape: torch.Size([512, 512])\n", - "Key not found in source model\n", - "\n", - "Processing key: encoder.mid_block.attentions.0.to_k.bias\n", - "Target shape: torch.Size([512])\n", - "Key not found in source model\n", 
- "\n", - "Processing key: encoder.mid_block.attentions.0.to_v.weight\n", - "Target shape: torch.Size([512, 512])\n", - "Key not found in source model\n", - "\n", - "Processing key: encoder.mid_block.attentions.0.to_v.bias\n", - "Target shape: torch.Size([512])\n", - "Key not found in source model\n", - "\n", - "Processing key: encoder.mid_block.attentions.0.to_out.0.weight\n", - "Target shape: torch.Size([512, 512])\n", - "Key not found in source model\n", - "\n", - "Processing key: encoder.mid_block.attentions.0.to_out.0.bias\n", - "Target shape: torch.Size([512])\n", - "Key not found in source model\n", - "\n", - "Processing key: encoder.mid_block.resnets.0.norm1.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.0.norm1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.0.conv1.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.0.conv1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.0.norm2.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.0.norm2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.0.conv2.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.0.conv2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.1.norm1.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.1.norm1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: 
encoder.mid_block.resnets.1.conv1.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.1.conv1.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.1.norm2.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.1.norm2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.1.conv2.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.mid_block.resnets.1.conv2.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.conv_norm_out.weight\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.conv_norm_out.bias\n", - "Target shape: torch.Size([512])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.conv_out.weight\n", - "Target shape: torch.Size([32, 512, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: encoder.conv_out.bias\n", - "Target shape: torch.Size([32])\n", - "Direct copy...\n", - "\n", - "Processing key: decoder.conv_in.weight\n", - "Target shape: torch.Size([768, 16, 3, 3])\n", - "Size mismatch: torch.Size([512, 16, 3, 3]) vs torch.Size([768, 16, 3, 3])\n", - "\n", - "Processing key: decoder.conv_in.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.0.norm1.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.0.norm1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: 
decoder.up_blocks.0.resnets.0.conv1.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.0.conv1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.0.norm2.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.0.norm2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.0.conv2.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.0.conv2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.1.norm1.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.1.norm1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.1.conv1.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.1.conv1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.1.norm2.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing 
key: decoder.up_blocks.0.resnets.1.norm2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.1.conv2.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.1.conv2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.2.norm1.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.2.norm1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.2.conv1.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.2.conv1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.2.norm2.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.2.norm2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.2.conv2.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.0.resnets.2.conv2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - 
"Processing key: decoder.up_blocks.0.upsamplers.0.conv.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.0.upsamplers.0.conv.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.0.norm1.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.0.norm1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.0.conv1.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.0.conv1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.0.norm2.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.0.norm2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.0.conv2.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.0.conv2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.1.norm1.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - 
"\n", - "Processing key: decoder.up_blocks.1.resnets.1.norm1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.1.conv1.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.1.conv1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.1.norm2.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.1.norm2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.1.conv2.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.1.conv2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.2.norm1.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.2.norm1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.2.conv1.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.2.conv1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", 
- "\n", - "Processing key: decoder.up_blocks.1.resnets.2.norm2.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.2.norm2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.2.conv2.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.1.resnets.2.conv2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.1.upsamplers.0.conv.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.1.upsamplers.0.conv.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.norm1.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.norm1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.conv1.weight\n", - "Target shape: torch.Size([512, 768, 3, 3])\n", - "Size mismatch: torch.Size([256, 512, 3, 3]) vs torch.Size([512, 768, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.conv1.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.norm2.weight\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs 
torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.norm2.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.conv2.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.conv2.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.conv_shortcut.weight\n", - "Target shape: torch.Size([512, 768, 1, 1])\n", - "Size mismatch: torch.Size([256, 512, 1, 1]) vs torch.Size([512, 768, 1, 1])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.0.conv_shortcut.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.1.norm1.weight\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.1.norm1.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.1.conv1.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.1.conv1.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.1.norm2.weight\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.1.norm2.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: 
torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.1.conv2.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.1.conv2.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.2.norm1.weight\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.2.norm1.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.2.conv1.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.2.conv1.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.2.norm2.weight\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.2.norm2.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.2.conv2.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.2.resnets.2.conv2.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.2.upsamplers.0.conv.weight\n", - "Target shape: torch.Size([512, 512, 3, 3])\n", - 
"Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.2.upsamplers.0.conv.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.norm1.weight\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.norm1.bias\n", - "Target shape: torch.Size([512])\n", - "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.conv1.weight\n", - "Target shape: torch.Size([256, 512, 3, 3])\n", - "Size mismatch: torch.Size([128, 256, 3, 3]) vs torch.Size([256, 512, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.conv1.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.norm2.weight\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.norm2.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.conv2.weight\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.conv2.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.conv_shortcut.weight\n", - "Target shape: torch.Size([256, 512, 1, 1])\n", - "Size mismatch: torch.Size([128, 256, 1, 1]) vs torch.Size([256, 512, 1, 1])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.0.conv_shortcut.bias\n", 
- "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.1.norm1.weight\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.1.norm1.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.1.conv1.weight\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.1.conv1.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.1.norm2.weight\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.1.norm2.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.1.conv2.weight\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.1.conv2.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.2.norm1.weight\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.2.norm1.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.2.conv1.weight\n", - "Target shape: 
torch.Size([256, 256, 3, 3])\n", - "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.2.conv1.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.2.norm2.weight\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.2.norm2.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.2.conv2.weight\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", - "\n", - "Processing key: decoder.up_blocks.3.resnets.2.conv2.bias\n", - "Target shape: torch.Size([256])\n", - "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", - "\n", - "Processing key: decoder.up_blocks.3.upsamplers.0.conv.weight\n", - "Target shape: torch.Size([256, 256, 3, 3])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.up_blocks.3.upsamplers.0.conv.bias\n", - "Target shape: torch.Size([256])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.norm1.weight\n", - "Target shape: torch.Size([256])\n", - "Found source key: decoder.up_blocks.3.resnets.0.norm1.weight\n", - "Source shape: torch.Size([256])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.norm1.bias\n", - "Target shape: torch.Size([256])\n", - "Found source key: decoder.up_blocks.3.resnets.0.norm1.bias\n", - "Source shape: torch.Size([256])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.conv1.weight\n", - "Target shape: torch.Size([128, 256, 3, 3])\n", - "Found source key: 
decoder.up_blocks.3.resnets.0.conv1.weight\n", - "Source shape: torch.Size([128, 256, 3, 3])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.conv1.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.0.conv1.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.norm2.weight\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.0.norm2.weight\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.norm2.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.0.norm2.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.conv2.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Found source key: decoder.up_blocks.3.resnets.0.conv2.weight\n", - "Source shape: torch.Size([128, 128, 3, 3])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.conv2.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.0.conv2.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.conv_shortcut.weight\n", - "Target shape: torch.Size([128, 256, 1, 1])\n", - "Found source key: decoder.up_blocks.3.resnets.0.conv_shortcut.weight\n", - "Source shape: torch.Size([128, 256, 1, 1])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.0.conv_shortcut.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.0.conv_shortcut.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes 
match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.1.norm1.weight\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.1.norm1.weight\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.1.norm1.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.1.norm1.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.1.conv1.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Found source key: decoder.up_blocks.3.resnets.1.conv1.weight\n", - "Source shape: torch.Size([128, 128, 3, 3])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.1.conv1.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.1.conv1.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.1.norm2.weight\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.1.norm2.weight\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.1.norm2.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.1.norm2.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.1.conv2.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Found source key: decoder.up_blocks.3.resnets.1.conv2.weight\n", - "Source shape: torch.Size([128, 128, 3, 3])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.1.conv2.bias\n", - "Target shape: 
torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.1.conv2.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.2.norm1.weight\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.2.norm1.weight\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.2.norm1.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.2.norm1.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.2.conv1.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Found source key: decoder.up_blocks.3.resnets.2.conv1.weight\n", - "Source shape: torch.Size([128, 128, 3, 3])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.2.conv1.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.2.conv1.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.2.norm2.weight\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.2.norm2.weight\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.2.norm2.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.2.norm2.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.2.conv2.weight\n", - "Target shape: torch.Size([128, 128, 3, 3])\n", - "Found source key: decoder.up_blocks.3.resnets.2.conv2.weight\n", - "Source shape: torch.Size([128, 128, 3, 3])\n", - "Shapes 
match, copying directly...\n", - "\n", - "Processing key: decoder.up_blocks.4.resnets.2.conv2.bias\n", - "Target shape: torch.Size([128])\n", - "Found source key: decoder.up_blocks.3.resnets.2.conv2.bias\n", - "Source shape: torch.Size([128])\n", - "Shapes match, copying directly...\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.group_norm.weight\n", - "Target shape: torch.Size([768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.group_norm.bias\n", - "Target shape: torch.Size([768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.to_q.weight\n", - "Target shape: torch.Size([768, 768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.to_q.bias\n", - "Target shape: torch.Size([768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.to_k.weight\n", - "Target shape: torch.Size([768, 768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.to_k.bias\n", - "Target shape: torch.Size([768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.to_v.weight\n", - "Target shape: torch.Size([768, 768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.to_v.bias\n", - "Target shape: torch.Size([768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.to_out.0.weight\n", - "Target shape: torch.Size([768, 768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.attentions.0.to_out.0.bias\n", - "Target shape: torch.Size([768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.mid_block.resnets.0.norm1.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - 
"Processing key: decoder.mid_block.resnets.0.norm1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.mid_block.resnets.0.conv1.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.mid_block.resnets.0.conv1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.mid_block.resnets.0.norm2.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.mid_block.resnets.0.norm2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.mid_block.resnets.0.conv2.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.mid_block.resnets.0.conv2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.mid_block.resnets.1.norm1.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.mid_block.resnets.1.norm1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.mid_block.resnets.1.conv1.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.mid_block.resnets.1.conv1.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: 
decoder.mid_block.resnets.1.norm2.weight\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.mid_block.resnets.1.norm2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.mid_block.resnets.1.conv2.weight\n", - "Target shape: torch.Size([768, 768, 3, 3])\n", - "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", - "\n", - "Processing key: decoder.mid_block.resnets.1.conv2.bias\n", - "Target shape: torch.Size([768])\n", - "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", - "\n", - "Processing key: decoder.condition_encoder.layers.0.weight\n", - "Target shape: torch.Size([128, 3, 3, 3])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.condition_encoder.layers.0.bias\n", - "Target shape: torch.Size([128])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.condition_encoder.layers.1.weight\n", - "Target shape: torch.Size([256, 128, 3, 3])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.condition_encoder.layers.1.bias\n", - "Target shape: torch.Size([256])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.condition_encoder.layers.2.weight\n", - "Target shape: torch.Size([512, 256, 4, 4])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.condition_encoder.layers.2.bias\n", - "Target shape: torch.Size([512])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.condition_encoder.layers.3.weight\n", - "Target shape: torch.Size([768, 512, 4, 4])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.condition_encoder.layers.3.bias\n", - "Target shape: torch.Size([768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.condition_encoder.layers.4.weight\n", - "Target 
shape: torch.Size([768, 768, 4, 4])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.condition_encoder.layers.4.bias\n", - "Target shape: torch.Size([768])\n", - "Key not found in source model\n", - "\n", - "Processing key: decoder.conv_norm_out.weight\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: decoder.conv_norm_out.bias\n", - "Target shape: torch.Size([128])\n", - "Direct copy...\n", - "\n", - "Processing key: decoder.conv_out.weight\n", - "Target shape: torch.Size([3, 128, 3, 3])\n", - "Direct copy...\n", - "\n", - "Processing key: decoder.conv_out.bias\n", - "Target shape: torch.Size([3])\n", - "Direct copy...\n", - "\n", - "Processing key: quant_conv.weight\n", - "Target shape: torch.Size([32, 32, 1, 1])\n", - "Direct copy...\n", - "\n", - "Processing key: quant_conv.bias\n", - "Target shape: torch.Size([32])\n", - "Direct copy...\n", - "\n", - "Processing key: post_quant_conv.weight\n", - "Target shape: torch.Size([16, 16, 1, 1])\n", - "Direct copy...\n", - "\n", - "Processing key: post_quant_conv.bias\n", - "Target shape: torch.Size([16])\n", - "Direct copy...\n", - "\n", - "Updating state dict...\n", - "\n", - "Loading state dict...\n", - "\n", - "Saving model...\n", - "\n", - "Transfer statistics: {'перенесено': 130, 'несовпадение_размеров': 124, 'пропущено': 32}\n" - ] - } - ], - "source": [ - "import torch\n", - "from diffusers import AsymmetricAutoencoderKL,AutoencoderKL\n", - "from tqdm import tqdm\n", - "\n", - "def log(message):\n", - " print(message)\n", - "\n", - "def interpolate_tensor(tensor, target_shape):\n", - " \"\"\"Интерполяция тензора до целевой формы\"\"\"\n", - " print(f\"Interpolating tensor of shape {tensor.shape} to target shape {target_shape}\")\n", - " \n", - " # Создаем новый тензор нужного размера\n", - " result = torch.zeros(target_shape, device=tensor.device, dtype=tensor.dtype)\n", - " \n", - " if len(tensor.shape) == 1: # Для 1D тензоров (bias)\n", - " 
min_size = min(tensor.shape[0], target_shape[0])\n", - " result[:min_size] = tensor[:min_size]\n", - " if target_shape[0] > min_size:\n", - " result[min_size:] = tensor[min_size-1]\n", - " else: # Для всех остальных тензоров\n", - " # Просто копируем то, что можем, остальное оставляем нулями\n", - " if len(tensor.shape) == len(target_shape):\n", - " for i in range(len(tensor.shape)):\n", - " if tensor.shape[i] > target_shape[i]:\n", - " tensor = tensor.narrow(i, 0, target_shape[i])\n", - " result[tuple(slice(0, s) for s in tensor.shape)] = tensor\n", - " \n", - " return result\n", - "\n", - "def main():\n", - " checkpoint_path_old = \"AiArtLab/sdxs\"\n", - " checkpoint_path_new = \"simple_vae\"\n", - " device = \"cuda\"\n", - " dtype = torch.float16\n", - "\n", - " print(\"Loading models...\")\n", - " old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old,subfolder=\"vae\",variant=\"fp16\").to(device, dtype=dtype)\n", - " new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n", - "\n", - " old_state_dict = old_unet.state_dict()\n", - " new_state_dict = new_unet.state_dict()\n", - "\n", - " transferred_state_dict = {}\n", - " transfer_stats = {\"перенесено\": 0, \"несовпадение_размеров\": 0, \"пропущено\": 0}\n", - "\n", - " print(\"\\nProcessing weights...\")\n", - " for new_key in tqdm(new_state_dict.keys()):\n", - " print(f\"\\nProcessing key: {new_key}\")\n", - " print(f\"Target shape: {new_state_dict[new_key].shape}\")\n", - "\n", - " # Для блоков 4 и 5 используем веса из блока 3\n", - " if 'decoder.up_blocks.4.' in new_key or 'decoder.up_blocks.5.' 
in new_key:\n", - " source_key = new_key.replace('decoder.up_blocks.4.', 'decoder.up_blocks.3.')\n", - " source_key = source_key.replace('decoder.up_blocks.5.', 'decoder.up_blocks.3.')\n", - " \n", - " if source_key in old_state_dict:\n", - " print(f\"Found source key: {source_key}\")\n", - " source_tensor = old_state_dict[source_key]\n", - " print(f\"Source shape: {source_tensor.shape}\")\n", - " \n", - " if source_tensor.shape != new_state_dict[new_key].shape:\n", - " print(\"Shapes don't match, interpolating...\")\n", - " transferred_state_dict[new_key] = interpolate_tensor(source_tensor, new_state_dict[new_key].shape)\n", - " else:\n", - " print(\"Shapes match, copying directly...\")\n", - " transferred_state_dict[new_key] = source_tensor.clone()\n", - " transfer_stats[\"перенесено\"] += 1\n", - " continue\n", - "\n", - " # Для остальных ключей пробуем прямой перенос\n", - " if new_key in old_state_dict:\n", - " if old_state_dict[new_key].shape == new_state_dict[new_key].shape:\n", - " print(\"Direct copy...\")\n", - " transferred_state_dict[new_key] = old_state_dict[new_key].clone()\n", - " transfer_stats[\"перенесено\"] += 1\n", - " else:\n", - " print(f\"Size mismatch: {old_state_dict[new_key].shape} vs {new_state_dict[new_key].shape}\")\n", - " transfer_stats[\"несовпадение_размеров\"] += 1\n", - " else:\n", - " print(\"Key not found in source model\")\n", - " transfer_stats[\"пропущено\"] += 1\n", - "\n", - " print(\"\\nUpdating state dict...\")\n", - " new_state_dict.update(transferred_state_dict)\n", - "\n", - " print(\"\\nLoading state dict...\")\n", - " new_unet.load_state_dict(new_state_dict)\n", - "\n", - " print(\"\\nSaving model...\")\n", - " new_unet.save_pretrained(\"vae\")\n", - "\n", - " print(\"\\nTransfer statistics:\", transfer_stats)\n", - "\n", - "if __name__ == \"__main__\":\n", - " main()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "54b1c67b-8eab-4bca-9cfe-b05bcc966abb", - "metadata": {}, - "outputs": [], - 
"source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}