{ "cells": [ { "cell_type": "code", "execution_count": 7, "id": "407171be-ab46-442b-a0bd-83ca75173eba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AsymmetricAutoencoderKL(\n", " (encoder): Encoder(\n", " (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (down_blocks): ModuleList(\n", " (0): DownEncoderBlock2D(\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n", " )\n", " )\n", " )\n", " (1): DownEncoderBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n", " )\n", " )\n", " )\n", " (2): DownEncoderBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (downsamplers): ModuleList(\n", " (0): Downsample2D(\n", " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n", " )\n", " )\n", " )\n", " (3): DownEncoderBlock2D(\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " )\n", " (mid_block): UNetMidBlock2D(\n", " (attentions): ModuleList(\n", " (0): Attention(\n", " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (to_q): Linear(in_features=512, out_features=512, bias=True)\n", " (to_k): Linear(in_features=512, out_features=512, bias=True)\n", " (to_v): Linear(in_features=512, out_features=512, bias=True)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=512, out_features=512, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv_act): SiLU()\n", " (conv_out): Conv2d(512, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " (decoder): MaskConditionDecoder(\n", " (conv_in): Conv2d(16, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (up_blocks): ModuleList(\n", " (0-1): 2 x UpDecoderBlock2D(\n", " (resnets): ModuleList(\n", " (0-2): 3 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 768, eps=1e-06, affine=True)\n", " (conv1): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 768, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (upsamplers): ModuleList(\n", " (0): Upsample2D(\n", " (conv): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (2): UpDecoderBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 768, eps=1e-06, affine=True)\n", " (conv1): Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1-2): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (upsamplers): ModuleList(\n", " (0): Upsample2D(\n", " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (3): UpDecoderBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n", " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1-2): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " (upsamplers): ModuleList(\n", " (0): Upsample2D(\n", " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " )\n", " )\n", " (4): UpDecoderBlock2D(\n", " (resnets): ModuleList(\n", " (0): ResnetBlock2D(\n", " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n", " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", " (1-2): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " )\n", " (mid_block): UNetMidBlock2D(\n", " (attentions): ModuleList(\n", " (0): Attention(\n", " (group_norm): GroupNorm(32, 768, eps=1e-06, affine=True)\n", " (to_q): Linear(in_features=768, out_features=768, bias=True)\n", " (to_k): Linear(in_features=768, out_features=768, bias=True)\n", " (to_v): Linear(in_features=768, out_features=768, bias=True)\n", " (to_out): ModuleList(\n", " (0): Linear(in_features=768, out_features=768, bias=True)\n", " (1): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (resnets): ModuleList(\n", " (0-1): 2 x ResnetBlock2D(\n", " (norm1): GroupNorm(32, 768, eps=1e-06, affine=True)\n", " (conv1): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (norm2): GroupNorm(32, 768, eps=1e-06, affine=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (conv2): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (nonlinearity): SiLU()\n", " )\n", " )\n", " )\n", " (condition_encoder): MaskConditionEncoder(\n", " (layers): Sequential(\n", " (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (2): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (3): Conv2d(512, 768, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " (4): Conv2d(768, 768, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", " )\n", " )\n", " (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n", " (conv_act): SiLU()\n", " (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " (quant_conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n", " (post_quant_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))\n", ")\n" ] } ], "source": [ "from diffusers.models import AsymmetricAutoencoderKL\n", "import torch\n", "\n", "config = {\n", " \"_class_name\": \"AsymmetricAutoencoderKL\",\n", " \"act_fn\": \"silu\",\n", " \"down_block_out_channels\": [128, 256, 512, 512],\n", " \"down_block_types\": [\n", " \"DownEncoderBlock2D\",\n", " \"DownEncoderBlock2D\",\n", " \"DownEncoderBlock2D\",\n", " \"DownEncoderBlock2D\",\n", " ],\n", " \"in_channels\": 3,\n", " \"latent_channels\": 16,\n", " #\"latents_mean\": [0.2539, 0.1431, 0.1484, -0.3048, -0.0985, -0.162, 0.1403, 0.2034, -0.1419, 0.2646, 0.0655, 0.0061, 0.1555, 0.0506, 0.0129, -0.1948],\n", " #\"latents_std\": [0.8123, 0.7376, 0.7354, 1.1827, 0.8387, 0.8735, 0.8705, 0.8142, 0.8076, 0.7409, 0.7655, 0.8731, 0.8087, 0.7058, 0.8087, 0.7615],\n", " #\"layers_per_block\": 2,\n", " \"norm_num_groups\": 32,\n", " \"out_channels\": 3,\n", " \"sample_size\": 1024,\n", " \"scaling_factor\": 1,\n", " \"shift_factor\": 0,\n", " \"up_block_out_channels\": [128, 256, 512, 768, 768],\n", " \"up_block_types\": [\n", " \"UpDecoderBlock2D\",\n", " \"UpDecoderBlock2D\",\n", " \"UpDecoderBlock2D\",\n", " \"UpDecoderBlock2D\",\n", " \"UpDecoderBlock2D\",\n", " ],\n", "}\n", "\n", "# Преобразуем списки mean и std в тензоры\n", "#latents_mean = torch.tensor(config[\"latents_mean\"])\n", "#latents_std = torch.tensor(config[\"latents_std\"])\n", "\n", "# Создаем модель\n", "vae = AsymmetricAutoencoderKL(\n", " act_fn=config[\"act_fn\"],\n", " down_block_out_channels=config[\"down_block_out_channels\"],\n", " down_block_types=config[\"down_block_types\"],\n", " in_channels=config[\"in_channels\"],\n", " latent_channels=config[\"latent_channels\"],\n", " norm_num_groups=config[\"norm_num_groups\"],\n", " out_channels=config[\"out_channels\"],\n", " sample_size=config[\"sample_size\"],\n", " scaling_factor=config[\"scaling_factor\"],\n", " up_block_out_channels=config[\"up_block_out_channels\"],\n", " up_block_types=config[\"up_block_types\"],\n", " layers_per_down_block = 2,\n", " layers_per_up_block = 2\n", ")\n", "\n", "# Устанавливаем mean и std для латентов\n", "#vae.latents_mean = latents_mean\n", "#vae.latents_std = latents_std\n", "\n", "vae.save_pretrained(\"simple_vae\")\n", "print(vae)" ] }, { "cell_type": "code", "execution_count": 8, "id": "290a6758-5aa8-47a4-ba2f-ece8abf6df88", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The config attributes {'block_out_channels': [128, 256, 512, 768, 768], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n", "Перенос весов: 100%|██████████| 228/228 [00:00<00:00, 192647.32it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "✗ Несовпадение размеров: decoder.conv_in.weight (torch.Size([512, 16, 3, 3])) -> decoder.conv_in.weight (torch.Size([768, 16, 3, 3]))\n", "✗ Несовпадение размеров: decoder.conv_in.bias (torch.Size([512])) -> decoder.conv_in.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm1.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm1.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.0.conv1.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.conv1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm2.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm2.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.0.conv2.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.conv2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.norm1.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.norm1.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.norm1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.norm1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.1.conv1.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.conv1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.conv1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.norm2.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.norm2.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.norm2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.norm2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.1.conv2.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.1.conv2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.1.conv2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.norm1.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.norm1.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.norm1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.norm1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.2.conv1.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.conv1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.conv1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.norm2.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.norm2.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.norm2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.norm2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.2.conv2.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.resnets.2.conv2.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.2.conv2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.upsamplers.0.conv.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.upsamplers.0.conv.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.0.upsamplers.0.conv.bias (torch.Size([512])) -> decoder.up_blocks.0.upsamplers.0.conv.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.norm1.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.norm1.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.norm1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.norm1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.0.conv1.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.conv1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.conv1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.norm2.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.norm2.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.norm2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.norm2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.0.conv2.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.0.conv2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.0.conv2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.norm1.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.norm1.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.norm1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.norm1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.1.conv1.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.conv1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.conv1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.norm2.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.norm2.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.norm2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.norm2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.1.conv2.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.1.conv2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.1.conv2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.norm1.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.norm1.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.norm1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.norm1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.2.conv1.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.conv1.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.conv1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.norm2.weight (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.norm2.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.norm2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.norm2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.resnets.2.conv2.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.resnets.2.conv2.bias (torch.Size([512])) -> decoder.up_blocks.1.resnets.2.conv2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.upsamplers.0.conv.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.1.upsamplers.0.conv.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.1.upsamplers.0.conv.bias (torch.Size([512])) -> decoder.up_blocks.1.upsamplers.0.conv.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.norm1.weight (torch.Size([512])) -> decoder.up_blocks.2.resnets.0.norm1.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.norm1.bias (torch.Size([512])) -> decoder.up_blocks.2.resnets.0.norm1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv1.weight (torch.Size([256, 512, 3, 3])) -> decoder.up_blocks.2.resnets.0.conv1.weight (torch.Size([512, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.conv1.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.norm2.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.norm2.weight (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.norm2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.norm2.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv2.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.conv2.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv_shortcut.weight (torch.Size([256, 512, 1, 1])) -> decoder.up_blocks.2.resnets.0.conv_shortcut.weight (torch.Size([512, 768, 1, 1]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.0.conv_shortcut.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.0.conv_shortcut.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.norm1.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.norm1.weight (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.norm1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.norm1.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.conv1.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.conv1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.conv1.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.norm2.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.norm2.weight (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.norm2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.norm2.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.conv2.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.1.conv2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.1.conv2.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.norm1.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.norm1.weight (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.norm1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.norm1.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.conv1.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.2.conv1.weight (torch.Size([512, 512, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.conv1.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.conv1.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.norm2.weight (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.norm2.weight (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.norm2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.norm2.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.conv2.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.resnets.2.conv2.weight (torch.Size([512, 512, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.resnets.2.conv2.bias (torch.Size([256])) -> decoder.up_blocks.2.resnets.2.conv2.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.upsamplers.0.conv.weight (torch.Size([256, 256, 3, 3])) -> decoder.up_blocks.2.upsamplers.0.conv.weight (torch.Size([512, 512, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.2.upsamplers.0.conv.bias (torch.Size([256])) -> decoder.up_blocks.2.upsamplers.0.conv.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.norm1.weight (torch.Size([256])) -> decoder.up_blocks.3.resnets.0.norm1.weight (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.norm1.bias (torch.Size([256])) -> decoder.up_blocks.3.resnets.0.norm1.bias (torch.Size([512]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv1.weight (torch.Size([128, 256, 3, 3])) -> decoder.up_blocks.3.resnets.0.conv1.weight (torch.Size([256, 512, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.conv1.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.norm2.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.norm2.weight (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.norm2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.norm2.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv2.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.0.conv2.weight (torch.Size([256, 256, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.conv2.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv_shortcut.weight (torch.Size([128, 256, 1, 1])) -> decoder.up_blocks.3.resnets.0.conv_shortcut.weight (torch.Size([256, 512, 1, 1]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.0.conv_shortcut.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.0.conv_shortcut.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.norm1.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.norm1.weight (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.norm1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.norm1.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.conv1.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.1.conv1.weight (torch.Size([256, 256, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.conv1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.conv1.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.norm2.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.norm2.weight (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.norm2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.norm2.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.conv2.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.1.conv2.weight (torch.Size([256, 256, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.1.conv2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.1.conv2.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.norm1.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.norm1.weight (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.norm1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.norm1.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.conv1.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.2.conv1.weight (torch.Size([256, 256, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.conv1.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.conv1.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.norm2.weight (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.norm2.weight (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.norm2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.norm2.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.conv2.weight (torch.Size([128, 128, 3, 3])) -> decoder.up_blocks.3.resnets.2.conv2.weight (torch.Size([256, 256, 3, 3]))\n", "✗ Несовпадение размеров: decoder.up_blocks.3.resnets.2.conv2.bias (torch.Size([128])) -> decoder.up_blocks.3.resnets.2.conv2.bias (torch.Size([256]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.0.norm1.weight (torch.Size([512])) -> decoder.mid_block.resnets.0.norm1.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.0.norm1.bias (torch.Size([512])) -> decoder.mid_block.resnets.0.norm1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.0.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.mid_block.resnets.0.conv1.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.0.conv1.bias (torch.Size([512])) -> decoder.mid_block.resnets.0.conv1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.0.norm2.weight (torch.Size([512])) -> decoder.mid_block.resnets.0.norm2.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.0.norm2.bias (torch.Size([512])) -> decoder.mid_block.resnets.0.norm2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.mid_block.resnets.0.conv2.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.0.conv2.bias (torch.Size([512])) -> decoder.mid_block.resnets.0.conv2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.1.norm1.weight (torch.Size([512])) -> decoder.mid_block.resnets.1.norm1.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.1.norm1.bias (torch.Size([512])) -> decoder.mid_block.resnets.1.norm1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.mid_block.resnets.1.conv1.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.1.conv1.bias (torch.Size([512])) -> decoder.mid_block.resnets.1.conv1.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.1.norm2.weight (torch.Size([512])) -> decoder.mid_block.resnets.1.norm2.weight (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.1.norm2.bias (torch.Size([512])) -> decoder.mid_block.resnets.1.norm2.bias (torch.Size([768]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3])) -> decoder.mid_block.resnets.1.conv2.weight (torch.Size([768, 768, 3, 3]))\n", "✗ Несовпадение размеров: decoder.mid_block.resnets.1.conv2.bias (torch.Size([512])) -> decoder.mid_block.resnets.1.conv2.bias (torch.Size([768]))\n", "Статистика переноса: {'перенесено': 104, 'несовпадение_размеров': 124, 'пропущено': 0}\n", "Неперенесенные ключи в новой модели:\n", "decoder.condition_encoder.layers.0.bias\n", "decoder.condition_encoder.layers.0.weight\n", "decoder.condition_encoder.layers.1.bias\n", "decoder.condition_encoder.layers.1.weight\n", "decoder.condition_encoder.layers.2.bias\n", "decoder.condition_encoder.layers.2.weight\n", "decoder.condition_encoder.layers.3.bias\n", "decoder.condition_encoder.layers.3.weight\n", "decoder.condition_encoder.layers.4.bias\n", "decoder.condition_encoder.layers.4.weight\n", "decoder.conv_in.bias\n", "decoder.conv_in.weight\n", "decoder.mid_block.attentions.0.group_norm.bias\n", "decoder.mid_block.attentions.0.group_norm.weight\n", "decoder.mid_block.attentions.0.to_k.bias\n", "decoder.mid_block.attentions.0.to_k.weight\n", "decoder.mid_block.attentions.0.to_out.0.bias\n", "decoder.mid_block.attentions.0.to_out.0.weight\n", "decoder.mid_block.attentions.0.to_q.bias\n", "decoder.mid_block.attentions.0.to_q.weight\n", "decoder.mid_block.attentions.0.to_v.bias\n", "decoder.mid_block.attentions.0.to_v.weight\n", "decoder.mid_block.resnets.0.conv1.bias\n", "decoder.mid_block.resnets.0.conv1.weight\n", "decoder.mid_block.resnets.0.conv2.bias\n", "decoder.mid_block.resnets.0.conv2.weight\n", "decoder.mid_block.resnets.0.norm1.bias\n", "decoder.mid_block.resnets.0.norm1.weight\n", "decoder.mid_block.resnets.0.norm2.bias\n", "decoder.mid_block.resnets.0.norm2.weight\n", "decoder.mid_block.resnets.1.conv1.bias\n", "decoder.mid_block.resnets.1.conv1.weight\n", "decoder.mid_block.resnets.1.conv2.bias\n", "decoder.mid_block.resnets.1.conv2.weight\n", "decoder.mid_block.resnets.1.norm1.bias\n", "decoder.mid_block.resnets.1.norm1.weight\n", "decoder.mid_block.resnets.1.norm2.bias\n", "decoder.mid_block.resnets.1.norm2.weight\n", "decoder.up_blocks.0.resnets.0.conv1.bias\n", "decoder.up_blocks.0.resnets.0.conv1.weight\n", "decoder.up_blocks.0.resnets.0.conv2.bias\n", "decoder.up_blocks.0.resnets.0.conv2.weight\n", "decoder.up_blocks.0.resnets.0.norm1.bias\n", "decoder.up_blocks.0.resnets.0.norm1.weight\n", "decoder.up_blocks.0.resnets.0.norm2.bias\n", "decoder.up_blocks.0.resnets.0.norm2.weight\n", "decoder.up_blocks.0.resnets.1.conv1.bias\n", "decoder.up_blocks.0.resnets.1.conv1.weight\n", "decoder.up_blocks.0.resnets.1.conv2.bias\n", "decoder.up_blocks.0.resnets.1.conv2.weight\n", "decoder.up_blocks.0.resnets.1.norm1.bias\n", "decoder.up_blocks.0.resnets.1.norm1.weight\n", "decoder.up_blocks.0.resnets.1.norm2.bias\n", "decoder.up_blocks.0.resnets.1.norm2.weight\n", "decoder.up_blocks.0.resnets.2.conv1.bias\n", "decoder.up_blocks.0.resnets.2.conv1.weight\n", "decoder.up_blocks.0.resnets.2.conv2.bias\n", "decoder.up_blocks.0.resnets.2.conv2.weight\n", "decoder.up_blocks.0.resnets.2.norm1.bias\n", "decoder.up_blocks.0.resnets.2.norm1.weight\n", "decoder.up_blocks.0.resnets.2.norm2.bias\n", "decoder.up_blocks.0.resnets.2.norm2.weight\n", "decoder.up_blocks.0.upsamplers.0.conv.bias\n", "decoder.up_blocks.0.upsamplers.0.conv.weight\n", "decoder.up_blocks.1.resnets.0.conv1.bias\n", "decoder.up_blocks.1.resnets.0.conv1.weight\n", "decoder.up_blocks.1.resnets.0.conv2.bias\n", "decoder.up_blocks.1.resnets.0.conv2.weight\n", "decoder.up_blocks.1.resnets.0.norm1.bias\n", "decoder.up_blocks.1.resnets.0.norm1.weight\n", "decoder.up_blocks.1.resnets.0.norm2.bias\n", "decoder.up_blocks.1.resnets.0.norm2.weight\n", "decoder.up_blocks.1.resnets.1.conv1.bias\n", "decoder.up_blocks.1.resnets.1.conv1.weight\n", "decoder.up_blocks.1.resnets.1.conv2.bias\n", "decoder.up_blocks.1.resnets.1.conv2.weight\n", "decoder.up_blocks.1.resnets.1.norm1.bias\n", "decoder.up_blocks.1.resnets.1.norm1.weight\n", "decoder.up_blocks.1.resnets.1.norm2.bias\n", "decoder.up_blocks.1.resnets.1.norm2.weight\n", "decoder.up_blocks.1.resnets.2.conv1.bias\n", "decoder.up_blocks.1.resnets.2.conv1.weight\n", "decoder.up_blocks.1.resnets.2.conv2.bias\n", "decoder.up_blocks.1.resnets.2.conv2.weight\n", "decoder.up_blocks.1.resnets.2.norm1.bias\n", "decoder.up_blocks.1.resnets.2.norm1.weight\n", "decoder.up_blocks.1.resnets.2.norm2.bias\n", "decoder.up_blocks.1.resnets.2.norm2.weight\n", "decoder.up_blocks.1.upsamplers.0.conv.bias\n", "decoder.up_blocks.1.upsamplers.0.conv.weight\n", "decoder.up_blocks.2.resnets.0.conv1.bias\n", "decoder.up_blocks.2.resnets.0.conv1.weight\n", "decoder.up_blocks.2.resnets.0.conv2.bias\n", "decoder.up_blocks.2.resnets.0.conv2.weight\n", "decoder.up_blocks.2.resnets.0.conv_shortcut.bias\n", "decoder.up_blocks.2.resnets.0.conv_shortcut.weight\n", "decoder.up_blocks.2.resnets.0.norm1.bias\n", "decoder.up_blocks.2.resnets.0.norm1.weight\n", "decoder.up_blocks.2.resnets.0.norm2.bias\n", "decoder.up_blocks.2.resnets.0.norm2.weight\n", "decoder.up_blocks.2.resnets.1.conv1.bias\n", "decoder.up_blocks.2.resnets.1.conv1.weight\n", "decoder.up_blocks.2.resnets.1.conv2.bias\n", "decoder.up_blocks.2.resnets.1.conv2.weight\n", "decoder.up_blocks.2.resnets.1.norm1.bias\n", "decoder.up_blocks.2.resnets.1.norm1.weight\n", "decoder.up_blocks.2.resnets.1.norm2.bias\n", "decoder.up_blocks.2.resnets.1.norm2.weight\n", "decoder.up_blocks.2.resnets.2.conv1.bias\n", "decoder.up_blocks.2.resnets.2.conv1.weight\n", "decoder.up_blocks.2.resnets.2.conv2.bias\n", "decoder.up_blocks.2.resnets.2.conv2.weight\n", "decoder.up_blocks.2.resnets.2.norm1.bias\n", "decoder.up_blocks.2.resnets.2.norm1.weight\n", "decoder.up_blocks.2.resnets.2.norm2.bias\n", "decoder.up_blocks.2.resnets.2.norm2.weight\n", "decoder.up_blocks.2.upsamplers.0.conv.bias\n", "decoder.up_blocks.2.upsamplers.0.conv.weight\n", "decoder.up_blocks.3.resnets.0.conv1.bias\n", "decoder.up_blocks.3.resnets.0.conv1.weight\n", "decoder.up_blocks.3.resnets.0.conv2.bias\n", "decoder.up_blocks.3.resnets.0.conv2.weight\n", "decoder.up_blocks.3.resnets.0.conv_shortcut.bias\n", "decoder.up_blocks.3.resnets.0.conv_shortcut.weight\n", "decoder.up_blocks.3.resnets.0.norm1.bias\n", "decoder.up_blocks.3.resnets.0.norm1.weight\n", "decoder.up_blocks.3.resnets.0.norm2.bias\n", "decoder.up_blocks.3.resnets.0.norm2.weight\n", "decoder.up_blocks.3.resnets.1.conv1.bias\n", "decoder.up_blocks.3.resnets.1.conv1.weight\n", "decoder.up_blocks.3.resnets.1.conv2.bias\n", "decoder.up_blocks.3.resnets.1.conv2.weight\n", "decoder.up_blocks.3.resnets.1.norm1.bias\n", "decoder.up_blocks.3.resnets.1.norm1.weight\n", "decoder.up_blocks.3.resnets.1.norm2.bias\n", "decoder.up_blocks.3.resnets.1.norm2.weight\n", "decoder.up_blocks.3.resnets.2.conv1.bias\n", "decoder.up_blocks.3.resnets.2.conv1.weight\n", "decoder.up_blocks.3.resnets.2.conv2.bias\n", "decoder.up_blocks.3.resnets.2.conv2.weight\n", "decoder.up_blocks.3.resnets.2.norm1.bias\n", "decoder.up_blocks.3.resnets.2.norm1.weight\n", "decoder.up_blocks.3.resnets.2.norm2.bias\n", "decoder.up_blocks.3.resnets.2.norm2.weight\n", "decoder.up_blocks.3.upsamplers.0.conv.bias\n", "decoder.up_blocks.3.upsamplers.0.conv.weight\n", "decoder.up_blocks.4.resnets.0.conv1.bias\n", "decoder.up_blocks.4.resnets.0.conv1.weight\n", "decoder.up_blocks.4.resnets.0.conv2.bias\n", "decoder.up_blocks.4.resnets.0.conv2.weight\n", "decoder.up_blocks.4.resnets.0.conv_shortcut.bias\n", "decoder.up_blocks.4.resnets.0.conv_shortcut.weight\n", "decoder.up_blocks.4.resnets.0.norm1.bias\n", "decoder.up_blocks.4.resnets.0.norm1.weight\n", "decoder.up_blocks.4.resnets.0.norm2.bias\n", "decoder.up_blocks.4.resnets.0.norm2.weight\n", "decoder.up_blocks.4.resnets.1.conv1.bias\n", "decoder.up_blocks.4.resnets.1.conv1.weight\n", "decoder.up_blocks.4.resnets.1.conv2.bias\n", "decoder.up_blocks.4.resnets.1.conv2.weight\n", "decoder.up_blocks.4.resnets.1.norm1.bias\n", "decoder.up_blocks.4.resnets.1.norm1.weight\n", "decoder.up_blocks.4.resnets.1.norm2.bias\n", "decoder.up_blocks.4.resnets.1.norm2.weight\n", "decoder.up_blocks.4.resnets.2.conv1.bias\n", "decoder.up_blocks.4.resnets.2.conv1.weight\n", "decoder.up_blocks.4.resnets.2.conv2.bias\n", "decoder.up_blocks.4.resnets.2.conv2.weight\n", "decoder.up_blocks.4.resnets.2.norm1.bias\n", "decoder.up_blocks.4.resnets.2.norm1.weight\n", "decoder.up_blocks.4.resnets.2.norm2.bias\n", "decoder.up_blocks.4.resnets.2.norm2.weight\n", "encoder.mid_block.attentions.0.group_norm.bias\n", "encoder.mid_block.attentions.0.group_norm.weight\n", "encoder.mid_block.attentions.0.to_k.bias\n", "encoder.mid_block.attentions.0.to_k.weight\n", "encoder.mid_block.attentions.0.to_out.0.bias\n", "encoder.mid_block.attentions.0.to_out.0.weight\n", "encoder.mid_block.attentions.0.to_q.bias\n", "encoder.mid_block.attentions.0.to_q.weight\n", "encoder.mid_block.attentions.0.to_v.bias\n", "encoder.mid_block.attentions.0.to_v.weight\n" ] } ], "source": [ "import torch\n", "from diffusers import AsymmetricAutoencoderKL,AutoencoderKL\n", "from tqdm import tqdm\n", "import torch.nn.init as init\n", "\n", "def log(message):\n", " print(message)\n", "\n", "def initialize_mid_block_weights(state_dict, device, dtype):\n", " # Инициализация весов для mid block 0 с размерностью 512\n", " state_dict['encoder.mid_block.attentions.0.group_norm.weight'] = torch.ones(512, device=device, dtype=dtype)\n", " state_dict['encoder.mid_block.attentions.0.group_norm.bias'] = torch.zeros(512, device=device, dtype=dtype)\n", " \n", " # Удаляем ключи для второго блока внимания, так как он не существует в архитектуре\n", " #if 'encoder.mid_block.attentions.1.group_norm.weight' in state_dict:\n", " # del state_dict['encoder.mid_block.attentions.1.group_norm.weight']\n", " #if 'encoder.mid_block.attentions.1.group_norm.bias' in state_dict:\n", " # del state_dict['encoder.mid_block.attentions.1.group_norm.bias']\n", " \n", " return state_dict\n", "\n", "def main():\n", " checkpoint_path_old = \"AiArtLab/sdxs\"\n", " checkpoint_path_new = \"simple_vae\"\n", " device = \"cuda\"\n", " dtype = torch.float16\n", "\n", " # Загрузка моделей\n", " old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old,subfolder=\"vae\",variant=\"fp16\").to(device, dtype=dtype)\n", " new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n", "\n", " old_state_dict = old_unet.state_dict()\n", " new_state_dict = new_unet.state_dict()\n", "\n", " transferred_state_dict = {}\n", " transfer_stats = {\n", " \"перенесено\": 0,\n", " \"несовпадение_размеров\": 0,\n", " \"пропущено\": 0\n", " }\n", "\n", " transferred_keys = set()\n", "\n", " # Обрабатываем каждый ключ старой модели\n", " for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n", " new_key = old_key\n", "\n", " if new_key in new_state_dict:\n", " if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n", " transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n", " transferred_keys.add(new_key)\n", " transfer_stats[\"перенесено\"] += 1\n", " else:\n", " log(f\"✗ Несовпадение размеров: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})\")\n", " transfer_stats[\"несовпадение_размеров\"] += 1\n", " else:\n", " log(f\"? Ключ не найден в новой модели: {old_key} -> {old_state_dict[old_key].shape}\")\n", " transfer_stats[\"пропущено\"] += 1\n", "\n", " # Обновляем состояние новой модели перенесенными весами\n", " new_state_dict.update(transferred_state_dict)\n", " \n", " # Инициализируем веса для нового mid блока\n", " new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)\n", " \n", " new_unet.load_state_dict(new_state_dict)\n", " new_unet.save_pretrained(\"vae\")\n", "\n", " # Получаем список неперенесенных ключей\n", " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n", "\n", " print(\"Статистика переноса:\", transfer_stats)\n", " print(\"Неперенесенные ключи в новой модели:\")\n", " for key in non_transferred_keys:\n", " print(key)\n", "\n", "if __name__ == \"__main__\":\n", " main()" ] }, { "cell_type": "code", "execution_count": 3, "id": "cb9f4324-96b7-4f49-8d32-61c1294739c7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The config attributes {'block_out_channels': [128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Weights before normalization:\n", "quant_conv.weight: torch.Size([32, 32, 1, 1])\n", "quant_conv.bias: torch.Size([32])\n", "post_quant_conv.weight: torch.Size([16, 16, 1, 1])\n", "post_quant_conv.bias: torch.Size([16])\n" ] } ], "source": [ "\n", "import torch\n", "from diffusers import AsymmetricAutoencoderKL\n", "from tqdm import tqdm\n", "\n", "def normalize_weights3(state_dict, latents_mean, latents_std):\n", " device = next(iter(state_dict.values())).device\n", " dtype = next(iter(state_dict.values())).dtype\n", " \n", " # Преобразуем в тензоры\n", " latents_mean = torch.tensor(latents_mean, device=device, dtype=dtype)\n", " latents_std = torch.tensor(latents_std, device=device, dtype=dtype)\n", "\n", " # Нормализация для quant_conv (32 -> 32 каналов)\n", " if 'quant_conv.weight' in state_dict:\n", " weight = state_dict['quant_conv.weight'] # [32, 32, 1, 1]\n", " # Применяем нормализацию к выходным каналам\n", " for i in range(weight.size(0)):\n", " weight[i] = weight[i] / latents_std[i % len(latents_std)]\n", " \n", " if 'quant_conv.bias' in state_dict:\n", " bias = state_dict['quant_conv.bias'] # [32]\n", " for i in range(bias.size(0)):\n", " bias[i] = (bias[i] - latents_mean[i % len(latents_mean)]) / latents_std[i % len(latents_std)]\n", "\n", " # Нормализация для post_quant_conv (16 -> 16 каналов)\n", " if 'post_quant_conv.weight' in state_dict:\n", " weight = state_dict['post_quant_conv.weight'] # [16, 16, 1, 1]\n", " # Применяем нормализацию к входным каналам\n", " for i in range(weight.size(1)):\n", " weight[:, i] = weight[:, i] * latents_std[i]\n", " \n", " if 'post_quant_conv.bias' in state_dict:\n", " bias = state_dict['post_quant_conv.bias'] # [16]\n", " for i in range(bias.size(0)):\n", " bias[i] = bias[i] * latents_std[i] + latents_mean[i]\n", "\n", " return state_dict\n", "\n", "def normalize_weights(state_dict, latents_mean, latents_std):\n", " device = next(iter(state_dict.values())).device\n", " dtype = next(iter(state_dict.values())).dtype\n", " \n", " # Преобразуем в тензоры\n", " latents_mean = torch.tensor(latents_mean, device=device, dtype=dtype)\n", " latents_std = torch.tensor(latents_std, device=device, dtype=dtype)\n", "\n", " # Нормализация для quant_conv (32 -> 32 каналов)\n", " # На выходе энкодера: (x - mean) / std\n", " if 'quant_conv.weight' in state_dict:\n", " weight = state_dict['quant_conv.weight'] # [32, 32, 1, 1]\n", " # Нормализуем выходные каналы\n", " for i in range(weight.size(0)):\n", " if i < len(latents_std):\n", " weight[i] = weight[i] / latents_std[i]\n", " \n", " if 'quant_conv.bias' in state_dict:\n", " bias = state_dict['quant_conv.bias'] # [32]\n", " for i in range(bias.size(0)):\n", " if i < len(latents_mean):\n", " # Сначала применяем сдвиг, потом масштабирование\n", " bias[i] = -latents_mean[i] / latents_std[i]\n", "\n", " # Нормализация для post_quant_conv (16 -> 16 каналов)\n", " # На входе декодера: x * std + mean\n", " if 'post_quant_conv.weight' in state_dict:\n", " weight = state_dict['post_quant_conv.weight'] # [16, 16, 1, 1]\n", " # Нормализуем входные каналы\n", " for i in range(weight.size(1)):\n", " if i < len(latents_std):\n", " weight[:, i] = weight[:, i] * latents_std[i]\n", " \n", " if 'post_quant_conv.bias' in state_dict:\n", " bias = state_dict['post_quant_conv.bias'] # [16]\n", " for i in range(bias.size(0)):\n", " if i < len(latents_mean):\n", " bias[i] = bias[i] + latents_mean[i]\n", "\n", " return state_dict\n", "\n", "def main():\n", " # Путь к модели\n", " model_path = \"vae\"\n", " device = \"cuda\"\n", " dtype = torch.float16\n", "\n", " # Ваши mean и std\n", " latents_mean = [0.2539, 0.1431, 0.1484, -0.3048, -0.0985, -0.162, 0.1403, 0.2034, -0.1419, 0.2646, 0.0655, 0.0061, 0.1555, 0.0506, 0.0129, -0.1948]\n", "\n", " latents_std = [0.8123, 0.7376, 0.7354, 1.1827, 0.8387, 0.8735, 0.8705, 0.8142, 0.8076, 0.7409, 0.7655, 0.8731, 0.8087, 0.7058, 0.8087, 0.7615]\n", "\n", " # Загружаем модель\n", " model = AsymmetricAutoencoderKL.from_pretrained(model_path).to(device, dtype=dtype)\n", " \n", " # Получаем state dict\n", " state_dict = model.state_dict()\n", "\n", " # Выводим информацию о весах до нормализации\n", " print(\"\\nWeights before normalization:\")\n", " for key in ['quant_conv.weight', 'quant_conv.bias', 'post_quant_conv.weight', 'post_quant_conv.bias']:\n", " if key in state_dict:\n", " print(f\"{key}: {state_dict[key].shape}\")\n", "\n", " # Нормализуем веса\n", " normalized_state_dict = normalize_weights(state_dict, latents_mean, latents_std)\n", " normalized_state_dict = initialize_mid_block_weights(normalized_state_dict, device, dtype)\n", "\n", " # Загружаем нормализованные веса обратно в модель\n", " model.load_state_dict(normalized_state_dict)\n", "\n", " # Сохраняем модель\n", " model.save_pretrained(\"vaenorm\")\n", "\n", "if __name__ == \"__main__\":\n", " main()" ] }, { "cell_type": "code", "execution_count": 6, "id": "43b0cb3d-776d-414d-9e08-b136a31f27a5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The config attributes {'block_out_channels': [128, 256, 512, 768, 768], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Processing decoder.up_blocks.3.resnets.0.norm1.weight\n", "Old shape: torch.Size([256])\n", "Target shape: torch.Size([512])\n", "Interpolating tensor of shape torch.Size([256]) to target shape torch.Size([512])\n", "Extending 1D tensor from 256 to 512\n", "\n", "Processing decoder.up_blocks.3.resnets.0.norm1.bias\n", "Old shape: torch.Size([256])\n", "Target shape: torch.Size([512])\n", "Interpolating tensor of shape torch.Size([256]) to target shape torch.Size([512])\n", "Extending 1D tensor from 256 to 512\n", "\n", "Processing decoder.up_blocks.3.resnets.0.conv1.weight\n", "Old shape: torch.Size([128, 256, 3, 3])\n", "Target shape: torch.Size([256, 512, 3, 3])\n", "Interpolating tensor of shape torch.Size([128, 256, 3, 3]) to target shape torch.Size([256, 512, 3, 3])\n", "Copying existing weights: min_out=128, min_in=256\n", "Extending output channels from 128 to 256\n", "Extending input channels from 256 to 512\n", "\n", "Processing decoder.up_blocks.3.resnets.0.conv1.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.0.norm2.weight\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.0.norm2.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.0.conv2.weight\n", "Old shape: torch.Size([128, 128, 3, 3])\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", "Copying existing weights: min_out=128, min_in=128\n", "Extending output channels from 128 to 256\n", "Extending input channels from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.0.conv2.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.0.conv_shortcut.weight\n", "Old shape: torch.Size([128, 256, 1, 1])\n", "Target shape: torch.Size([256, 512, 1, 1])\n", "Interpolating tensor of shape torch.Size([128, 256, 1, 1]) to target shape torch.Size([256, 512, 1, 1])\n", "Copying existing weights: min_out=128, min_in=256\n", "Extending output channels from 128 to 256\n", "Extending input channels from 256 to 512\n", "\n", "Processing decoder.up_blocks.3.resnets.0.conv_shortcut.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.1.norm1.weight\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.1.norm1.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.1.conv1.weight\n", "Old shape: torch.Size([128, 128, 3, 3])\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", "Copying existing weights: min_out=128, min_in=128\n", "Extending output channels from 128 to 256\n", "Extending input channels from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.1.conv1.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.1.norm2.weight\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.1.norm2.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.1.conv2.weight\n", "Old shape: torch.Size([128, 128, 3, 3])\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", "Copying existing weights: min_out=128, min_in=128\n", "Extending output channels from 128 to 256\n", "Extending input channels from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.1.conv2.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.2.norm1.weight\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.2.norm1.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.2.conv1.weight\n", "Old shape: torch.Size([128, 128, 3, 3])\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", "Copying existing weights: min_out=128, min_in=128\n", "Extending output channels from 128 to 256\n", "Extending input channels from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.2.conv1.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.2.norm2.weight\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.2.norm2.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.2.conv2.weight\n", "Old shape: torch.Size([128, 128, 3, 3])\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Interpolating tensor of shape torch.Size([128, 128, 3, 3]) to target shape torch.Size([256, 256, 3, 3])\n", "Copying existing weights: min_out=128, min_in=128\n", "Extending output channels from 128 to 256\n", "Extending input channels from 128 to 256\n", "\n", "Processing decoder.up_blocks.3.resnets.2.conv2.bias\n", "Old shape: torch.Size([128])\n", "Target shape: torch.Size([256])\n", "Interpolating tensor of shape torch.Size([128]) to target shape torch.Size([256])\n", "Extending 1D tensor from 128 to 256\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Перенос весов: 100%|██████████| 286/286 [00:00<00:00, 163407.02it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.weight -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.weight -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.bias -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.bias -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_q.weight -> torch.Size([512, 512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_q.weight -> torch.Size([512, 512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_q.bias -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_q.bias -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_k.weight -> torch.Size([512, 512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_k.weight -> torch.Size([512, 512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_k.bias -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_k.bias -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_v.weight -> torch.Size([512, 512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_v.weight -> torch.Size([512, 512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_v.bias -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_v.bias -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_out.0.weight -> torch.Size([512, 512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_out.0.weight -> torch.Size([512, 512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_out.0.bias -> torch.Size([512])\n", "? Ключ пропущен: encoder.mid_block.attentions.0.to_out.0.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.conv_in.weight -> torch.Size([768, 16, 3, 3])\n", "? Ключ пропущен: decoder.conv_in.weight -> torch.Size([768, 16, 3, 3])\n", "? Ключ пропущен: decoder.conv_in.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.conv_in.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.0.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.1.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.resnets.2.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.upsamplers.0.conv.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.upsamplers.0.conv.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.0.upsamplers.0.conv.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.0.upsamplers.0.conv.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.0.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.1.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.resnets.2.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.upsamplers.0.conv.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.upsamplers.0.conv.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.1.upsamplers.0.conv.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.1.upsamplers.0.conv.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv1.weight -> torch.Size([512, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv1.weight -> torch.Size([512, 768, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm2.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm2.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.norm2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv2.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv2.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv_shortcut.weight -> torch.Size([512, 768, 1, 1])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv_shortcut.weight -> torch.Size([512, 768, 1, 1])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv_shortcut.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.0.conv_shortcut.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm1.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm1.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv1.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv1.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm2.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm2.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.norm2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv2.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv2.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.1.conv2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm1.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm1.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv1.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv1.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv1.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm2.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm2.weight -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.norm2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv2.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv2.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.resnets.2.conv2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.upsamplers.0.conv.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.upsamplers.0.conv.weight -> torch.Size([512, 512, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.2.upsamplers.0.conv.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.2.upsamplers.0.conv.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.up_blocks.3.upsamplers.0.conv.weight -> torch.Size([256, 256, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.3.upsamplers.0.conv.weight -> torch.Size([256, 256, 3, 3])\n", "? Ключ пропущен: decoder.up_blocks.3.upsamplers.0.conv.bias -> torch.Size([256])\n", "? Ключ пропущен: decoder.up_blocks.3.upsamplers.0.conv.bias -> torch.Size([256])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.group_norm.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.group_norm.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.group_norm.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.group_norm.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_q.weight -> torch.Size([768, 768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_q.weight -> torch.Size([768, 768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_q.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_q.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_k.weight -> torch.Size([768, 768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_k.weight -> torch.Size([768, 768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_k.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_k.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_v.weight -> torch.Size([768, 768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_v.weight -> torch.Size([768, 768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_v.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_v.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_out.0.weight -> torch.Size([768, 768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_out.0.weight -> torch.Size([768, 768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_out.0.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.attentions.0.to_out.0.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.0.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.norm1.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.norm1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.conv1.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.conv1.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.norm2.weight -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.norm2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.conv2.weight -> torch.Size([768, 768, 3, 3])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.mid_block.resnets.1.conv2.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.condition_encoder.layers.0.weight -> torch.Size([128, 3, 3, 3])\n", "? Ключ пропущен: decoder.condition_encoder.layers.0.weight -> torch.Size([128, 3, 3, 3])\n", "? Ключ пропущен: decoder.condition_encoder.layers.0.bias -> torch.Size([128])\n", "? Ключ пропущен: decoder.condition_encoder.layers.0.bias -> torch.Size([128])\n", "? Ключ пропущен: decoder.condition_encoder.layers.1.weight -> torch.Size([256, 128, 3, 3])\n", "? Ключ пропущен: decoder.condition_encoder.layers.1.weight -> torch.Size([256, 128, 3, 3])\n", "? Ключ пропущен: decoder.condition_encoder.layers.1.bias -> torch.Size([256])\n", "? Ключ пропущен: decoder.condition_encoder.layers.1.bias -> torch.Size([256])\n", "? Ключ пропущен: decoder.condition_encoder.layers.2.weight -> torch.Size([512, 256, 4, 4])\n", "? Ключ пропущен: decoder.condition_encoder.layers.2.weight -> torch.Size([512, 256, 4, 4])\n", "? Ключ пропущен: decoder.condition_encoder.layers.2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.condition_encoder.layers.2.bias -> torch.Size([512])\n", "? Ключ пропущен: decoder.condition_encoder.layers.3.weight -> torch.Size([768, 512, 4, 4])\n", "? Ключ пропущен: decoder.condition_encoder.layers.3.weight -> torch.Size([768, 512, 4, 4])\n", "? Ключ пропущен: decoder.condition_encoder.layers.3.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.condition_encoder.layers.3.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.condition_encoder.layers.4.weight -> torch.Size([768, 768, 4, 4])\n", "? Ключ пропущен: decoder.condition_encoder.layers.4.weight -> torch.Size([768, 768, 4, 4])\n", "? Ключ пропущен: decoder.condition_encoder.layers.4.bias -> torch.Size([768])\n", "? Ключ пропущен: decoder.condition_encoder.layers.4.bias -> torch.Size([768])\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "ename": "RuntimeError", "evalue": "Error(s) in loading state_dict for AsymmetricAutoencoderKL:\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[6], line 164\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28mprint\u001b[39m(key)\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;18m__name__\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m__main__\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 164\u001b[0m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[6], line 153\u001b[0m, in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;66;03m# Инициализируем веса для нового mid блока\u001b[39;00m\n\u001b[1;32m 150\u001b[0m new_state_dict \u001b[38;5;241m=\u001b[39m initialize_mid_block_weights(new_state_dict, device, dtype)\n\u001b[0;32m--> 153\u001b[0m \u001b[43mnew_unet\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnew_state_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m new_unet\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvae\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 156\u001b[0m \u001b[38;5;66;03m# Выводим статистику\u001b[39;00m\n", "File \u001b[0;32m~/.local/lib/python3.11/site-packages/torch/nn/modules/module.py:2581\u001b[0m, in \u001b[0;36mModule.load_state_dict\u001b[0;34m(self, state_dict, strict, assign)\u001b[0m\n\u001b[1;32m 2573\u001b[0m error_msgs\u001b[38;5;241m.\u001b[39minsert(\n\u001b[1;32m 2574\u001b[0m \u001b[38;5;241m0\u001b[39m,\n\u001b[1;32m 2575\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMissing key(s) in state_dict: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2576\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m missing_keys)\n\u001b[1;32m 2577\u001b[0m ),\n\u001b[1;32m 2578\u001b[0m )\n\u001b[1;32m 2580\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(error_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 2581\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 2582\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mError(s) in loading state_dict for \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 2583\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)\n\u001b[1;32m 2584\u001b[0m )\n\u001b[1;32m 2585\u001b[0m )\n\u001b[1;32m 2586\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _IncompatibleKeys(missing_keys, unexpected_keys)\n", "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for AsymmetricAutoencoderKL:\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).\n\tsize mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.1.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).\n\tsize mismatch for decoder.up_blocks.4.resnets.2.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128])." ] } ], "source": [ "import torch\n", "from diffusers import AsymmetricAutoencoderKL,AutoencoderKL\n", "from tqdm import tqdm\n", "import torch.nn.init as init\n", "\n", "def log(message):\n", " print(message)\n", "\n", "def initialize_mid_block_weights(state_dict, device, dtype):\n", " # Инициализация весов для mid block 0 с размерностью 512\n", " state_dict['encoder.mid_block.attentions.0.group_norm.weight'] = torch.ones(512, device=device, dtype=dtype)\n", " state_dict['encoder.mid_block.attentions.0.group_norm.bias'] = torch.zeros(512, device=device, dtype=dtype)\n", " \n", " # Удаляем ключи для второго блока внимания, так как он не существует в архитектуре\n", " #if 'encoder.mid_block.attentions.1.group_norm.weight' in state_dict:\n", " # del state_dict['encoder.mid_block.attentions.1.group_norm.weight']\n", " #if 'encoder.mid_block.attentions.1.group_norm.bias' in state_dict:\n", " # del state_dict['encoder.mid_block.attentions.1.group_norm.bias']\n", " \n", " return state_dict\n", " \n", "def interpolate_tensor(tensor, target_shape):\n", " \"\"\"Интерполяция тензора до целевой формы\"\"\"\n", " print(f\"Interpolating tensor of shape {tensor.shape} to target shape {target_shape}\")\n", " \n", " if len(tensor.shape) == 4: # Для свёрточных слоев\n", " out_channels, in_channels, k1, k2 = target_shape\n", " \n", " # Создаем новый тензор нужного размера\n", " result = torch.zeros(target_shape, device=tensor.device, dtype=tensor.dtype)\n", " \n", " # Копируем существующие веса с масштабированием\n", " min_out = min(tensor.shape[0], out_channels)\n", " min_in = min(tensor.shape[1], in_channels)\n", " \n", " print(f\"Copying existing weights: min_out={min_out}, min_in={min_in}\")\n", " \n", " # Копируем существующие веса\n", " result[:min_out, :min_in, :, :] = tensor[:min_out, :min_in, :, :]\n", " \n", " # Заполняем новые выходные каналы\n", " if out_channels > min_out:\n", " print(f\"Extending output channels from {min_out} to {out_channels}\")\n", " result[min_out:, :min_in, :, :] = result[min_out-1:min_out, :min_in, :, :].repeat(out_channels-min_out, 1, 1, 1)\n", " \n", " # Заполняем новые входные каналы\n", " if in_channels > min_in:\n", " print(f\"Extending input channels from {min_in} to {in_channels}\")\n", " result[:, min_in:, :, :] = result[:, min_in-1:min_in, :, :].repeat(1, in_channels-min_in, 1, 1)\n", " \n", " return result\n", " \n", " else: # Для bias и других 1D тензоров\n", " # Создаем новый тензор нужного размера\n", " result = torch.zeros(target_shape, device=tensor.device, dtype=tensor.dtype)\n", " \n", " # Копируем существующие значения\n", " min_size = min(tensor.shape[0], target_shape[0])\n", " result[:min_size] = tensor[:min_size]\n", " \n", " # Заполняем оставшиеся значения\n", " if target_shape[0] > min_size:\n", " print(f\"Extending 1D tensor from {min_size} to {target_shape[0]}\")\n", " result[min_size:] = result[min_size-1]\n", " \n", " return result\n", "\n", "def should_interpolate(key):\n", " \"\"\"Определяет, нужно ли интерполировать веса для данного ключа\"\"\"\n", " return any(x in key for x in [\n", " 'conv1', 'conv2', 'conv_shortcut', # свёрточные слои\n", " 'norm1', 'norm2', 'group_norm', # нормализационные слои\n", " 'bias', 'weight' # веса и смещения\n", " ])\n", " \n", "def main():\n", " checkpoint_path_old = \"AiArtLab/sdxs\"\n", " checkpoint_path_new = \"simple_vae\"\n", " device = \"cuda\"\n", " dtype = torch.float16\n", "\n", " # Загрузка моделей\n", " old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old,subfolder=\"vae\",variant=\"fp16\").to(device, dtype=dtype)\n", " new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n", "\n", " old_state_dict = old_unet.state_dict()\n", " new_state_dict = new_unet.state_dict()\n", "\n", " transferred_state_dict = {}\n", " transfer_stats = {\n", " \"перенесено\": 0,\n", " \"несовпадение_размеров\": 0,\n", " \"пропущено\": 0\n", " }\n", "\n", " transferred_keys = set()\n", "\n", " # Сначала найдем все ключи блока 3 и их интерполированные значения\n", " block3_interpolated = {}\n", " for key in old_state_dict:\n", " if 'decoder.up_blocks.3.' in key and should_interpolate(key):\n", " old_tensor = old_state_dict[key]\n", " new_key = key\n", " if new_key in new_state_dict and old_tensor.shape != new_state_dict[new_key].shape:\n", " print(f\"\\nProcessing {key}\")\n", " print(f\"Old shape: {old_tensor.shape}\")\n", " print(f\"Target shape: {new_state_dict[new_key].shape}\")\n", " interpolated = interpolate_tensor(old_tensor, new_state_dict[new_key].shape)\n", " block3_interpolated[key] = interpolated\n", "\n", " # Обрабатываем каждый ключ новой модели\n", " for new_key in tqdm(new_state_dict.keys(), desc=\"Перенос весов\"):\n", " # Случай 1: Прямое соответствие ключей и размеров\n", " if new_key in old_state_dict and old_state_dict[new_key].shape == new_state_dict[new_key].shape:\n", " transferred_state_dict[new_key] = old_state_dict[new_key].clone()\n", " transferred_keys.add(new_key)\n", " transfer_stats[\"перенесено\"] += 1\n", " continue\n", "\n", " # Случай 2: Блоки 4 и 5 (копируем интерполированные веса блока 3)\n", " if ('decoder.up_blocks.4.' in new_key or 'decoder.up_blocks.5.' in new_key) and should_interpolate(new_key):\n", " source_key = new_key.replace('decoder.up_blocks.4.', 'decoder.up_blocks.3.')\n", " source_key = source_key.replace('decoder.up_blocks.5.', 'decoder.up_blocks.3.')\n", " \n", " if source_key in block3_interpolated:\n", " transferred_state_dict[new_key] = block3_interpolated[source_key].clone()\n", " transferred_keys.add(new_key)\n", " transfer_stats[\"перенесено\"] += 1\n", " continue\n", "\n", " # Случай 3: Несовпадение размеров в блоке 3\n", " if 'decoder.up_blocks.3.' in new_key and new_key in block3_interpolated:\n", " transferred_state_dict[new_key] = block3_interpolated[new_key].clone()\n", " transferred_keys.add(new_key)\n", " transfer_stats[\"перенесено\"] += 1\n", " continue\n", "\n", " # Если ключ не обработан - помечаем как пропущенный\n", " transfer_stats[\"пропущено\"] += 1\n", " log(f\"? Ключ пропущен: {new_key} -> {new_state_dict[new_key].shape}\")\n", "\n", " # Если ключ не обработан - помечаем как пропущенный\n", " transfer_stats[\"пропущено\"] += 1\n", " log(f\"? Ключ пропущен: {new_key} -> {new_state_dict[new_key].shape}\")\n", "\n", " # Обновляем состояние новой модели\n", " new_state_dict.update(transferred_state_dict)\n", "\n", " # Инициализируем веса для нового mid блока\n", " new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)\n", "\n", " \n", " new_unet.load_state_dict(new_state_dict)\n", " new_unet.save_pretrained(\"vae\")\n", "\n", " # Выводим статистику\n", " print(\"\\nСтатистика переноса:\", transfer_stats)\n", " print(\"\\nНеперенесенные ключи:\")\n", " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n", " for key in non_transferred_keys:\n", " print(key)\n", "\n", "if __name__ == \"__main__\":\n", " main()" ] }, { "cell_type": "code", "execution_count": 7, "id": "b3018b9a-82cf-4435-8afd-4448ff5e83a9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading models...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "The config attributes {'block_out_channels': [128, 256, 512, 768, 768], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Processing weights...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 286/286 [00:00<00:00, 106458.20it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Processing key: encoder.conv_in.weight\n", "Target shape: torch.Size([128, 3, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.conv_in.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.0.norm1.weight\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.0.norm1.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.0.conv1.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.0.conv1.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.0.norm2.weight\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.0.norm2.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.0.conv2.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.0.conv2.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.1.norm1.weight\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.1.norm1.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.1.conv1.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.1.conv1.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.1.norm2.weight\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.1.norm2.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.1.conv2.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.resnets.1.conv2.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.downsamplers.0.conv.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.0.downsamplers.0.conv.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.norm1.weight\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.norm1.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.conv1.weight\n", "Target shape: torch.Size([256, 128, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.conv1.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.norm2.weight\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.norm2.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.conv2.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.conv2.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.conv_shortcut.weight\n", "Target shape: torch.Size([256, 128, 1, 1])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.0.conv_shortcut.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.1.norm1.weight\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.1.norm1.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.1.conv1.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.1.conv1.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.1.norm2.weight\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.1.norm2.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.1.conv2.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.resnets.1.conv2.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.downsamplers.0.conv.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.1.downsamplers.0.conv.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.norm1.weight\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.norm1.bias\n", "Target shape: torch.Size([256])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.conv1.weight\n", "Target shape: torch.Size([512, 256, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.conv1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.norm2.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.norm2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.conv2.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.conv2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.conv_shortcut.weight\n", "Target shape: torch.Size([512, 256, 1, 1])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.0.conv_shortcut.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.1.norm1.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.1.norm1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.1.conv1.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.1.conv1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.1.norm2.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.1.norm2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.1.conv2.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.resnets.1.conv2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.downsamplers.0.conv.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.2.downsamplers.0.conv.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.0.norm1.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.0.norm1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.0.conv1.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.0.conv1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.0.norm2.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.0.norm2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.0.conv2.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.0.conv2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.1.norm1.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.1.norm1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.1.conv1.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.1.conv1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.1.norm2.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.1.norm2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.1.conv2.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.down_blocks.3.resnets.1.conv2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.attentions.0.group_norm.weight\n", "Target shape: torch.Size([512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.attentions.0.group_norm.bias\n", "Target shape: torch.Size([512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.attentions.0.to_q.weight\n", "Target shape: torch.Size([512, 512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.attentions.0.to_q.bias\n", "Target shape: torch.Size([512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.attentions.0.to_k.weight\n", "Target shape: torch.Size([512, 512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.attentions.0.to_k.bias\n", "Target shape: torch.Size([512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.attentions.0.to_v.weight\n", "Target shape: torch.Size([512, 512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.attentions.0.to_v.bias\n", "Target shape: torch.Size([512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.attentions.0.to_out.0.weight\n", "Target shape: torch.Size([512, 512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.attentions.0.to_out.0.bias\n", "Target shape: torch.Size([512])\n", "Key not found in source model\n", "\n", "Processing key: encoder.mid_block.resnets.0.norm1.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.0.norm1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.0.conv1.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.0.conv1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.0.norm2.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.0.norm2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.0.conv2.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.0.conv2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.1.norm1.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.1.norm1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.1.conv1.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.1.conv1.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.1.norm2.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.1.norm2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.1.conv2.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.mid_block.resnets.1.conv2.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.conv_norm_out.weight\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.conv_norm_out.bias\n", "Target shape: torch.Size([512])\n", "Direct copy...\n", "\n", "Processing key: encoder.conv_out.weight\n", "Target shape: torch.Size([32, 512, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: encoder.conv_out.bias\n", "Target shape: torch.Size([32])\n", "Direct copy...\n", "\n", "Processing key: decoder.conv_in.weight\n", "Target shape: torch.Size([768, 16, 3, 3])\n", "Size mismatch: torch.Size([512, 16, 3, 3]) vs torch.Size([768, 16, 3, 3])\n", "\n", "Processing key: decoder.conv_in.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.0.norm1.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.0.norm1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.0.conv1.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.0.conv1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.0.norm2.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.0.norm2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.0.conv2.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.0.conv2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.1.norm1.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.1.norm1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.1.conv1.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.1.conv1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.1.norm2.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.1.norm2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.1.conv2.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.1.conv2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.2.norm1.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.2.norm1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.2.conv1.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.2.conv1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.2.norm2.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.2.norm2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.2.conv2.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.0.resnets.2.conv2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.0.upsamplers.0.conv.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.0.upsamplers.0.conv.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.0.norm1.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.0.norm1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.0.conv1.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.0.conv1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.0.norm2.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.0.norm2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.0.conv2.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.0.conv2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.1.norm1.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.1.norm1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.1.conv1.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.1.conv1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.1.norm2.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.1.norm2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.1.conv2.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.1.conv2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.2.norm1.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.2.norm1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.2.conv1.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.2.conv1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.2.norm2.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.2.norm2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.2.conv2.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.1.resnets.2.conv2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.1.upsamplers.0.conv.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.1.upsamplers.0.conv.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.norm1.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.norm1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.conv1.weight\n", "Target shape: torch.Size([512, 768, 3, 3])\n", "Size mismatch: torch.Size([256, 512, 3, 3]) vs torch.Size([512, 768, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.conv1.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.norm2.weight\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.norm2.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.conv2.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.conv2.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.conv_shortcut.weight\n", "Target shape: torch.Size([512, 768, 1, 1])\n", "Size mismatch: torch.Size([256, 512, 1, 1]) vs torch.Size([512, 768, 1, 1])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.0.conv_shortcut.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.1.norm1.weight\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.1.norm1.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.1.conv1.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.1.conv1.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.1.norm2.weight\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.1.norm2.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.1.conv2.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.1.conv2.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.2.norm1.weight\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.2.norm1.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.2.conv1.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.2.conv1.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.2.norm2.weight\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.2.norm2.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.2.conv2.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.2.resnets.2.conv2.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.2.upsamplers.0.conv.weight\n", "Target shape: torch.Size([512, 512, 3, 3])\n", "Size mismatch: torch.Size([256, 256, 3, 3]) vs torch.Size([512, 512, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.2.upsamplers.0.conv.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.norm1.weight\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.norm1.bias\n", "Target shape: torch.Size([512])\n", "Size mismatch: torch.Size([256]) vs torch.Size([512])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.conv1.weight\n", "Target shape: torch.Size([256, 512, 3, 3])\n", "Size mismatch: torch.Size([128, 256, 3, 3]) vs torch.Size([256, 512, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.conv1.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.norm2.weight\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.norm2.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.conv2.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.conv2.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.conv_shortcut.weight\n", "Target shape: torch.Size([256, 512, 1, 1])\n", "Size mismatch: torch.Size([128, 256, 1, 1]) vs torch.Size([256, 512, 1, 1])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.0.conv_shortcut.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.1.norm1.weight\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.1.norm1.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.1.conv1.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.1.conv1.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.1.norm2.weight\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.1.norm2.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.1.conv2.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.1.conv2.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.2.norm1.weight\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.2.norm1.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.2.conv1.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.2.conv1.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.2.norm2.weight\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.2.norm2.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.2.conv2.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Size mismatch: torch.Size([128, 128, 3, 3]) vs torch.Size([256, 256, 3, 3])\n", "\n", "Processing key: decoder.up_blocks.3.resnets.2.conv2.bias\n", "Target shape: torch.Size([256])\n", "Size mismatch: torch.Size([128]) vs torch.Size([256])\n", "\n", "Processing key: decoder.up_blocks.3.upsamplers.0.conv.weight\n", "Target shape: torch.Size([256, 256, 3, 3])\n", "Key not found in source model\n", "\n", "Processing key: decoder.up_blocks.3.upsamplers.0.conv.bias\n", "Target shape: torch.Size([256])\n", "Key not found in source model\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.norm1.weight\n", "Target shape: torch.Size([256])\n", "Found source key: decoder.up_blocks.3.resnets.0.norm1.weight\n", "Source shape: torch.Size([256])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.norm1.bias\n", "Target shape: torch.Size([256])\n", "Found source key: decoder.up_blocks.3.resnets.0.norm1.bias\n", "Source shape: torch.Size([256])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.conv1.weight\n", "Target shape: torch.Size([128, 256, 3, 3])\n", "Found source key: decoder.up_blocks.3.resnets.0.conv1.weight\n", "Source shape: torch.Size([128, 256, 3, 3])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.conv1.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.0.conv1.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.norm2.weight\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.0.norm2.weight\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.norm2.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.0.norm2.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.conv2.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Found source key: decoder.up_blocks.3.resnets.0.conv2.weight\n", "Source shape: torch.Size([128, 128, 3, 3])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.conv2.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.0.conv2.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.conv_shortcut.weight\n", "Target shape: torch.Size([128, 256, 1, 1])\n", "Found source key: decoder.up_blocks.3.resnets.0.conv_shortcut.weight\n", "Source shape: torch.Size([128, 256, 1, 1])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.0.conv_shortcut.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.0.conv_shortcut.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.1.norm1.weight\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.1.norm1.weight\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.1.norm1.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.1.norm1.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.1.conv1.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Found source key: decoder.up_blocks.3.resnets.1.conv1.weight\n", "Source shape: torch.Size([128, 128, 3, 3])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.1.conv1.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.1.conv1.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.1.norm2.weight\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.1.norm2.weight\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.1.norm2.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.1.norm2.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.1.conv2.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Found source key: decoder.up_blocks.3.resnets.1.conv2.weight\n", "Source shape: torch.Size([128, 128, 3, 3])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.1.conv2.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.1.conv2.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.2.norm1.weight\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.2.norm1.weight\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.2.norm1.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.2.norm1.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.2.conv1.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Found source key: decoder.up_blocks.3.resnets.2.conv1.weight\n", "Source shape: torch.Size([128, 128, 3, 3])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.2.conv1.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.2.conv1.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.2.norm2.weight\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.2.norm2.weight\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.2.norm2.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.2.norm2.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.2.conv2.weight\n", "Target shape: torch.Size([128, 128, 3, 3])\n", "Found source key: decoder.up_blocks.3.resnets.2.conv2.weight\n", "Source shape: torch.Size([128, 128, 3, 3])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.up_blocks.4.resnets.2.conv2.bias\n", "Target shape: torch.Size([128])\n", "Found source key: decoder.up_blocks.3.resnets.2.conv2.bias\n", "Source shape: torch.Size([128])\n", "Shapes match, copying directly...\n", "\n", "Processing key: decoder.mid_block.attentions.0.group_norm.weight\n", "Target shape: torch.Size([768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.attentions.0.group_norm.bias\n", "Target shape: torch.Size([768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.attentions.0.to_q.weight\n", "Target shape: torch.Size([768, 768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.attentions.0.to_q.bias\n", "Target shape: torch.Size([768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.attentions.0.to_k.weight\n", "Target shape: torch.Size([768, 768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.attentions.0.to_k.bias\n", "Target shape: torch.Size([768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.attentions.0.to_v.weight\n", "Target shape: torch.Size([768, 768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.attentions.0.to_v.bias\n", "Target shape: torch.Size([768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.attentions.0.to_out.0.weight\n", "Target shape: torch.Size([768, 768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.attentions.0.to_out.0.bias\n", "Target shape: torch.Size([768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.mid_block.resnets.0.norm1.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.0.norm1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.0.conv1.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.mid_block.resnets.0.conv1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.0.norm2.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.0.norm2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.0.conv2.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.mid_block.resnets.0.conv2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.1.norm1.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.1.norm1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.1.conv1.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.mid_block.resnets.1.conv1.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.1.norm2.weight\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.1.norm2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.mid_block.resnets.1.conv2.weight\n", "Target shape: torch.Size([768, 768, 3, 3])\n", "Size mismatch: torch.Size([512, 512, 3, 3]) vs torch.Size([768, 768, 3, 3])\n", "\n", "Processing key: decoder.mid_block.resnets.1.conv2.bias\n", "Target shape: torch.Size([768])\n", "Size mismatch: torch.Size([512]) vs torch.Size([768])\n", "\n", "Processing key: decoder.condition_encoder.layers.0.weight\n", "Target shape: torch.Size([128, 3, 3, 3])\n", "Key not found in source model\n", "\n", "Processing key: decoder.condition_encoder.layers.0.bias\n", "Target shape: torch.Size([128])\n", "Key not found in source model\n", "\n", "Processing key: decoder.condition_encoder.layers.1.weight\n", "Target shape: torch.Size([256, 128, 3, 3])\n", "Key not found in source model\n", "\n", "Processing key: decoder.condition_encoder.layers.1.bias\n", "Target shape: torch.Size([256])\n", "Key not found in source model\n", "\n", "Processing key: decoder.condition_encoder.layers.2.weight\n", "Target shape: torch.Size([512, 256, 4, 4])\n", "Key not found in source model\n", "\n", "Processing key: decoder.condition_encoder.layers.2.bias\n", "Target shape: torch.Size([512])\n", "Key not found in source model\n", "\n", "Processing key: decoder.condition_encoder.layers.3.weight\n", "Target shape: torch.Size([768, 512, 4, 4])\n", "Key not found in source model\n", "\n", "Processing key: decoder.condition_encoder.layers.3.bias\n", "Target shape: torch.Size([768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.condition_encoder.layers.4.weight\n", "Target shape: torch.Size([768, 768, 4, 4])\n", "Key not found in source model\n", "\n", "Processing key: decoder.condition_encoder.layers.4.bias\n", "Target shape: torch.Size([768])\n", "Key not found in source model\n", "\n", "Processing key: decoder.conv_norm_out.weight\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: decoder.conv_norm_out.bias\n", "Target shape: torch.Size([128])\n", "Direct copy...\n", "\n", "Processing key: decoder.conv_out.weight\n", "Target shape: torch.Size([3, 128, 3, 3])\n", "Direct copy...\n", "\n", "Processing key: decoder.conv_out.bias\n", "Target shape: torch.Size([3])\n", "Direct copy...\n", "\n", "Processing key: quant_conv.weight\n", "Target shape: torch.Size([32, 32, 1, 1])\n", "Direct copy...\n", "\n", "Processing key: quant_conv.bias\n", "Target shape: torch.Size([32])\n", "Direct copy...\n", "\n", "Processing key: post_quant_conv.weight\n", "Target shape: torch.Size([16, 16, 1, 1])\n", "Direct copy...\n", "\n", "Processing key: post_quant_conv.bias\n", "Target shape: torch.Size([16])\n", "Direct copy...\n", "\n", "Updating state dict...\n", "\n", "Loading state dict...\n", "\n", "Saving model...\n", "\n", "Transfer statistics: {'перенесено': 130, 'несовпадение_размеров': 124, 'пропущено': 32}\n" ] } ], "source": [ "import torch\n", "from diffusers import AsymmetricAutoencoderKL,AutoencoderKL\n", "from tqdm import tqdm\n", "\n", "def log(message):\n", " print(message)\n", "\n", "def interpolate_tensor(tensor, target_shape):\n", " \"\"\"Интерполяция тензора до целевой формы\"\"\"\n", " print(f\"Interpolating tensor of shape {tensor.shape} to target shape {target_shape}\")\n", " \n", " # Создаем новый тензор нужного размера\n", " result = torch.zeros(target_shape, device=tensor.device, dtype=tensor.dtype)\n", " \n", " if len(tensor.shape) == 1: # Для 1D тензоров (bias)\n", " min_size = min(tensor.shape[0], target_shape[0])\n", " result[:min_size] = tensor[:min_size]\n", " if target_shape[0] > min_size:\n", " result[min_size:] = tensor[min_size-1]\n", " else: # Для всех остальных тензоров\n", " # Просто копируем то, что можем, остальное оставляем нулями\n", " if len(tensor.shape) == len(target_shape):\n", " for i in range(len(tensor.shape)):\n", " if tensor.shape[i] > target_shape[i]:\n", " tensor = tensor.narrow(i, 0, target_shape[i])\n", " result[tuple(slice(0, s) for s in tensor.shape)] = tensor\n", " \n", " return result\n", "\n", "def main():\n", " checkpoint_path_old = \"AiArtLab/sdxs\"\n", " checkpoint_path_new = \"simple_vae\"\n", " device = \"cuda\"\n", " dtype = torch.float16\n", "\n", " print(\"Loading models...\")\n", " old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old,subfolder=\"vae\",variant=\"fp16\").to(device, dtype=dtype)\n", " new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n", "\n", " old_state_dict = old_unet.state_dict()\n", " new_state_dict = new_unet.state_dict()\n", "\n", " transferred_state_dict = {}\n", " transfer_stats = {\"перенесено\": 0, \"несовпадение_размеров\": 0, \"пропущено\": 0}\n", "\n", " print(\"\\nProcessing weights...\")\n", " for new_key in tqdm(new_state_dict.keys()):\n", " print(f\"\\nProcessing key: {new_key}\")\n", " print(f\"Target shape: {new_state_dict[new_key].shape}\")\n", "\n", " # Для блоков 4 и 5 используем веса из блока 3\n", " if 'decoder.up_blocks.4.' in new_key or 'decoder.up_blocks.5.' in new_key:\n", " source_key = new_key.replace('decoder.up_blocks.4.', 'decoder.up_blocks.3.')\n", " source_key = source_key.replace('decoder.up_blocks.5.', 'decoder.up_blocks.3.')\n", " \n", " if source_key in old_state_dict:\n", " print(f\"Found source key: {source_key}\")\n", " source_tensor = old_state_dict[source_key]\n", " print(f\"Source shape: {source_tensor.shape}\")\n", " \n", " if source_tensor.shape != new_state_dict[new_key].shape:\n", " print(\"Shapes don't match, interpolating...\")\n", " transferred_state_dict[new_key] = interpolate_tensor(source_tensor, new_state_dict[new_key].shape)\n", " else:\n", " print(\"Shapes match, copying directly...\")\n", " transferred_state_dict[new_key] = source_tensor.clone()\n", " transfer_stats[\"перенесено\"] += 1\n", " continue\n", "\n", " # Для остальных ключей пробуем прямой перенос\n", " if new_key in old_state_dict:\n", " if old_state_dict[new_key].shape == new_state_dict[new_key].shape:\n", " print(\"Direct copy...\")\n", " transferred_state_dict[new_key] = old_state_dict[new_key].clone()\n", " transfer_stats[\"перенесено\"] += 1\n", " else:\n", " print(f\"Size mismatch: {old_state_dict[new_key].shape} vs {new_state_dict[new_key].shape}\")\n", " transfer_stats[\"несовпадение_размеров\"] += 1\n", " else:\n", " print(\"Key not found in source model\")\n", " transfer_stats[\"пропущено\"] += 1\n", "\n", " print(\"\\nUpdating state dict...\")\n", " new_state_dict.update(transferred_state_dict)\n", "\n", " print(\"\\nLoading state dict...\")\n", " new_unet.load_state_dict(new_state_dict)\n", "\n", " print(\"\\nSaving model...\")\n", " new_unet.save_pretrained(\"vae\")\n", "\n", " print(\"\\nTransfer statistics:\", transfer_stats)\n", "\n", "if __name__ == \"__main__\":\n", " main()" ] }, { "cell_type": "code", "execution_count": null, "id": "54b1c67b-8eab-4bca-9cfe-b05bcc966abb", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }