recoilme commited on
Commit
452841e
·
verified ·
1 Parent(s): 228e9c9

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.json +73 -0
  2. diffusion_pytorch_model.safetensors +3 -0
  3. unet1.3b.ipynb +1063 -0
config.json ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.36.0",
4
+ "_name_or_path": "unet",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "attention_type": "default",
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280,
15
+ 1280
16
+ ],
17
+ "center_input_sample": false,
18
+ "class_embed_type": null,
19
+ "class_embeddings_concat": false,
20
+ "conv_in_kernel": 3,
21
+ "conv_out_kernel": 3,
22
+ "cross_attention_dim": 768,
23
+ "cross_attention_norm": null,
24
+ "down_block_types": [
25
+ "CrossAttnDownBlock2D",
26
+ "CrossAttnDownBlock2D",
27
+ "CrossAttnDownBlock2D",
28
+ "DownBlock2D"
29
+ ],
30
+ "downsample_padding": 1,
31
+ "dropout": 0.0,
32
+ "dual_cross_attention": false,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "in_channels": 128,
38
+ "layers_per_block": 2,
39
+ "mid_block_only_cross_attention": null,
40
+ "mid_block_scale_factor": 1.0,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "out_channels": 128,
48
+ "projection_class_embeddings_input_dim": null,
49
+ "resnet_out_scale_factor": 1.0,
50
+ "resnet_skip_time_act": false,
51
+ "resnet_time_scale_shift": "default",
52
+ "reverse_transformer_layers_per_block": null,
53
+ "sample_size": null,
54
+ "time_cond_proj_dim": null,
55
+ "time_embedding_act_fn": null,
56
+ "time_embedding_dim": null,
57
+ "time_embedding_type": "positional",
58
+ "timestep_post_act": null,
59
+ "transformer_layers_per_block": [
60
+ 2,
61
+ 2,
62
+ 3,
63
+ 3
64
+ ],
65
+ "up_block_types": [
66
+ "UpBlock2D",
67
+ "CrossAttnUpBlock2D",
68
+ "CrossAttnUpBlock2D",
69
+ "CrossAttnUpBlock2D"
70
+ ],
71
+ "upcast_attention": false,
72
+ "use_linear_projection": false
73
+ }
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfdd24ecfa87e4d096bb769ef3b84ee7dab25a69679e4f6b85967bcc97ac8272
3
+ size 5166216656
unet1.3b.ipynb ADDED
@@ -0,0 +1,1063 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "82ca7882-410c-4067-863a-07838d485f6a",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "test unet\n",
14
+ "Количество параметров: 1344407376\n",
15
+ "Output shape: torch.Size([1, 16, 60, 48])\n",
16
+ "Output shape: torch.Size([1, 16, 60, 48])\n"
17
+ ]
18
+ }
19
+ ],
20
+ "source": [
21
+ "config_sdxs = {\n",
22
+ " # === Основные размеры и каналы ===\n",
23
+ " \"in_channels\": 16, # Количество входных каналов (совместимость с 16-канальным VAE)\n",
24
+ " \"out_channels\": 16, # Количество выходных каналов (симметрично in_channels)\n",
25
+ " \"center_input_sample\": False, # Отключение центрирования входных данных (стандарт для диффузионных моделей)\n",
26
+ " \"flip_sin_to_cos\": True, # Автоматическое преобразование sin/cos в эмбеддингах времени (для стабильности)\n",
27
+ " \"freq_shift\": 0, # Сдвиг частоты (0 - стандартное значение для частотных эмбеддингов)\n",
28
+ "\n",
29
+ " # === Архитектура блоков ===\n",
30
+ " \"down_block_types\": [ # Типы блоков энкодера (иерархия обработки):\n",
31
+ " \"CrossAttnDownBlock2D\",\n",
32
+ " \"CrossAttnDownBlock2D\",\n",
33
+ " \"CrossAttnDownBlock2D\",\n",
34
+ " \"DownBlock2D\"\n",
35
+ " ],\n",
36
+ " \"mid_block_type\": \"UNetMidBlock2DCrossAttn\", # Центральный блок с cross-attention (бутылочное горлышко сети)\n",
37
+ " \"up_block_types\": [ # Типы блоков декодера (восстановление изображения):\n",
38
+ " \"UpBlock2D\",\n",
39
+ " \"CrossAttnUpBlock2D\",\n",
40
+ " \"CrossAttnUpBlock2D\",\n",
41
+ " \"CrossAttnUpBlock2D\",\n",
42
+ " ],\n",
43
+ " \"only_cross_attention\": False, # Использование как cross-attention, так и self-attention\n",
44
+ "\n",
45
+ " # === Конфигурация каналов ===\n",
46
+ " \"block_out_channels\": [320, 640, 1280, 1280], \n",
47
+ " \"layers_per_block\": 2, # Число слоев в блоках\n",
48
+ " \"downsample_padding\": 1, # Паддинг при уменьшении разрешения\n",
49
+ " \"mid_block_scale_factor\": 1.0, # Усиление сигнала в центральном блоке\n",
50
+ "\n",
51
+ " # === Нормализация ===\n",
52
+ " \"norm_num_groups\": 32, # Число групп для GroupNorm (оптимально для стабильности)\n",
53
+ " \"norm_eps\": 1e-05, # Эпсилон для нормализации (стандартное значение)\n",
54
+ "\n",
55
+ " # === Cross-Attention ===\n",
56
+ " \"cross_attention_dim\": 768, # Размерность текстовых эмбеддинго\n",
57
+ " \n",
58
+ " \"transformer_layers_per_block\": 3, # Число трансформерных слоев (уменьшение с глубиной)\n",
59
+ " \"attention_head_dim\": 8, # Размерность головы внимания \n",
60
+ " \"dual_cross_attention\": False, # Отключение двойного внимания (упрощение архитектуры)\n",
61
+ " \"use_linear_projection\": False, # Изменено на True для лучшей организации памяти\n",
62
+ "\n",
63
+ " # === ResNet Блоки ===\n",
64
+ " \"resnet_time_scale_shift\": \"default\", # Способ интеграции временных эмбеддингов\n",
65
+ " \"resnet_skip_time_act\": False, # Отключение активации в skip-соединениях\n",
66
+ " \"resnet_out_scale_factor\": 1.0, # Коэффициент масштабирования выхода ResNet\n",
67
+ "\n",
68
+ " # === Временные эмбеддинги ===\n",
69
+ " \"time_embedding_type\": \"positional\", # Тип временных эмбеддингов (стандартный подход)\n",
70
+ "\n",
71
+ " # === Свертки ===\n",
72
+ " \"conv_in_kernel\": 3, # Ядро входной свертки (баланс между рецептивным полем и параметрами)\n",
73
+ " \"conv_out_kernel\": 3, # Ядро выходной свертки (симметрично входной)\n",
74
+ "}\n",
75
+ "\n",
76
+ "if 1:\n",
77
+ " checkpoint_path = \"sd15_tmp\"#\"sdxs\"\n",
78
+ " import torch\n",
79
+ " from diffusers import UNet2DConditionModel\n",
80
+ " print(\"test unet\")\n",
81
+ " new_unet = UNet2DConditionModel(**config_sdxs).to(\"cuda\", dtype=torch.float16)\n",
82
+ "\n",
83
+ " assert all(ch % 32 == 0 for ch in new_unet.config[\"block_out_channels\"]), \"Каналы должны быть кратны 32\"\n",
84
+ " num_params = sum(p.numel() for p in new_unet.parameters())\n",
85
+ " print(f\"Количество параметров: {num_params}\")\n",
86
+ "\n",
87
+ " # Генерация тестового латента (640x512 в latent space)\n",
88
+ " test_latent = torch.randn(1, 16, 60, 48).to(\"cuda\", dtype=torch.float16) # 60x48 ≈ 512px\n",
89
+ " timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n",
90
+ " encoder_hidden_states = torch.randn(1, 77, 768).to(\"cuda\", dtype=torch.float16)\n",
91
+ " \n",
92
+ " with torch.no_grad():\n",
93
+ " output = new_unet(\n",
94
+ " test_latent, \n",
95
+ " timesteps, \n",
96
+ " encoder_hidden_states\n",
97
+ " ).sample\n",
98
+ " \n",
99
+ " print(f\"Output shape: {output.shape}\") \n",
100
+ " new_unet.save_pretrained(checkpoint_path)\n",
101
+ " #print(new_unet)\n",
102
+ " del new_unet\n",
103
+ " torch.cuda.empty_cache()\n",
104
+ " print(f\"Output shape: {output.shape}\") \n",
105
+ " # Количество параметров: 1101998736 1344407376"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": 3,
111
+ "id": "f980bb1a-9859-44c2-a2df-ff1b073bf435",
112
+ "metadata": {},
113
+ "outputs": [
114
+ {
115
+ "name": "stderr",
116
+ "output_type": "stream",
117
+ "text": [
118
+ "Перенос весов: 100%|██████████| 1006/1006 [00:00<00:00, 36208.99it/s]\n"
119
+ ]
120
+ },
121
+ {
122
+ "name": "stdout",
123
+ "output_type": "stream",
124
+ "text": [
125
+ "Статистика переноса: {'перенесено': 1006, 'несовпадение_размеров': 0, 'пропущено': 0}\n",
126
+ "Неперенесенные ключи в новой модели:\n",
127
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k.weight\n",
128
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n",
129
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n",
130
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q.weight\n",
131
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v.weight\n",
132
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k.weight\n",
133
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n",
134
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n",
135
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q.weight\n",
136
+ "down_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v.weight\n",
137
+ "down_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n",
138
+ "down_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n",
139
+ "down_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.bias\n",
140
+ "down_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.weight\n",
141
+ "down_blocks.0.attentions.0.transformer_blocks.2.norm1.bias\n",
142
+ "down_blocks.0.attentions.0.transformer_blocks.2.norm1.weight\n",
143
+ "down_blocks.0.attentions.0.transformer_blocks.2.norm2.bias\n",
144
+ "down_blocks.0.attentions.0.transformer_blocks.2.norm2.weight\n",
145
+ "down_blocks.0.attentions.0.transformer_blocks.2.norm3.bias\n",
146
+ "down_blocks.0.attentions.0.transformer_blocks.2.norm3.weight\n",
147
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k.weight\n",
148
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n",
149
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n",
150
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q.weight\n",
151
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v.weight\n",
152
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k.weight\n",
153
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n",
154
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n",
155
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q.weight\n",
156
+ "down_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v.weight\n",
157
+ "down_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n",
158
+ "down_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n",
159
+ "down_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.bias\n",
160
+ "down_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.weight\n",
161
+ "down_blocks.0.attentions.1.transformer_blocks.2.norm1.bias\n",
162
+ "down_blocks.0.attentions.1.transformer_blocks.2.norm1.weight\n",
163
+ "down_blocks.0.attentions.1.transformer_blocks.2.norm2.bias\n",
164
+ "down_blocks.0.attentions.1.transformer_blocks.2.norm2.weight\n",
165
+ "down_blocks.0.attentions.1.transformer_blocks.2.norm3.bias\n",
166
+ "down_blocks.0.attentions.1.transformer_blocks.2.norm3.weight\n",
167
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight\n",
168
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n",
169
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n",
170
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight\n",
171
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight\n",
172
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight\n",
173
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n",
174
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n",
175
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight\n",
176
+ "down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight\n",
177
+ "down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n",
178
+ "down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n",
179
+ "down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias\n",
180
+ "down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight\n",
181
+ "down_blocks.1.attentions.0.transformer_blocks.2.norm1.bias\n",
182
+ "down_blocks.1.attentions.0.transformer_blocks.2.norm1.weight\n",
183
+ "down_blocks.1.attentions.0.transformer_blocks.2.norm2.bias\n",
184
+ "down_blocks.1.attentions.0.transformer_blocks.2.norm2.weight\n",
185
+ "down_blocks.1.attentions.0.transformer_blocks.2.norm3.bias\n",
186
+ "down_blocks.1.attentions.0.transformer_blocks.2.norm3.weight\n",
187
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight\n",
188
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n",
189
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n",
190
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight\n",
191
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight\n",
192
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight\n",
193
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n",
194
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n",
195
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight\n",
196
+ "down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight\n",
197
+ "down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n",
198
+ "down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n",
199
+ "down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias\n",
200
+ "down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight\n",
201
+ "down_blocks.1.attentions.1.transformer_blocks.2.norm1.bias\n",
202
+ "down_blocks.1.attentions.1.transformer_blocks.2.norm1.weight\n",
203
+ "down_blocks.1.attentions.1.transformer_blocks.2.norm2.bias\n",
204
+ "down_blocks.1.attentions.1.transformer_blocks.2.norm2.weight\n",
205
+ "down_blocks.1.attentions.1.transformer_blocks.2.norm3.bias\n",
206
+ "down_blocks.1.attentions.1.transformer_blocks.2.norm3.weight\n",
207
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight\n",
208
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n",
209
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n",
210
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight\n",
211
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight\n",
212
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight\n",
213
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n",
214
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n",
215
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight\n",
216
+ "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight\n",
217
+ "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n",
218
+ "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n",
219
+ "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias\n",
220
+ "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight\n",
221
+ "down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias\n",
222
+ "down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight\n",
223
+ "down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias\n",
224
+ "down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight\n",
225
+ "down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias\n",
226
+ "down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight\n",
227
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight\n",
228
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n",
229
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n",
230
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight\n",
231
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight\n",
232
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight\n",
233
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n",
234
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n",
235
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight\n",
236
+ "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight\n",
237
+ "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n",
238
+ "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n",
239
+ "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias\n",
240
+ "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight\n",
241
+ "down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias\n",
242
+ "down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight\n",
243
+ "down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias\n",
244
+ "down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight\n",
245
+ "down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias\n",
246
+ "down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight\n",
247
+ "mid_block.attentions.0.transformer_blocks.2.attn1.to_k.weight\n",
248
+ "mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n",
249
+ "mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n",
250
+ "mid_block.attentions.0.transformer_blocks.2.attn1.to_q.weight\n",
251
+ "mid_block.attentions.0.transformer_blocks.2.attn1.to_v.weight\n",
252
+ "mid_block.attentions.0.transformer_blocks.2.attn2.to_k.weight\n",
253
+ "mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n",
254
+ "mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n",
255
+ "mid_block.attentions.0.transformer_blocks.2.attn2.to_q.weight\n",
256
+ "mid_block.attentions.0.transformer_blocks.2.attn2.to_v.weight\n",
257
+ "mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n",
258
+ "mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n",
259
+ "mid_block.attentions.0.transformer_blocks.2.ff.net.2.bias\n",
260
+ "mid_block.attentions.0.transformer_blocks.2.ff.net.2.weight\n",
261
+ "mid_block.attentions.0.transformer_blocks.2.norm1.bias\n",
262
+ "mid_block.attentions.0.transformer_blocks.2.norm1.weight\n",
263
+ "mid_block.attentions.0.transformer_blocks.2.norm2.bias\n",
264
+ "mid_block.attentions.0.transformer_blocks.2.norm2.weight\n",
265
+ "mid_block.attentions.0.transformer_blocks.2.norm3.bias\n",
266
+ "mid_block.attentions.0.transformer_blocks.2.norm3.weight\n",
267
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight\n",
268
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n",
269
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n",
270
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight\n",
271
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight\n",
272
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight\n",
273
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n",
274
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n",
275
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight\n",
276
+ "up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight\n",
277
+ "up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n",
278
+ "up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n",
279
+ "up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias\n",
280
+ "up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight\n",
281
+ "up_blocks.1.attentions.0.transformer_blocks.2.norm1.bias\n",
282
+ "up_blocks.1.attentions.0.transformer_blocks.2.norm1.weight\n",
283
+ "up_blocks.1.attentions.0.transformer_blocks.2.norm2.bias\n",
284
+ "up_blocks.1.attentions.0.transformer_blocks.2.norm2.weight\n",
285
+ "up_blocks.1.attentions.0.transformer_blocks.2.norm3.bias\n",
286
+ "up_blocks.1.attentions.0.transformer_blocks.2.norm3.weight\n",
287
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight\n",
288
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n",
289
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n",
290
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight\n",
291
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight\n",
292
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight\n",
293
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n",
294
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n",
295
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight\n",
296
+ "up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight\n",
297
+ "up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n",
298
+ "up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n",
299
+ "up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias\n",
300
+ "up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight\n",
301
+ "up_blocks.1.attentions.1.transformer_blocks.2.norm1.bias\n",
302
+ "up_blocks.1.attentions.1.transformer_blocks.2.norm1.weight\n",
303
+ "up_blocks.1.attentions.1.transformer_blocks.2.norm2.bias\n",
304
+ "up_blocks.1.attentions.1.transformer_blocks.2.norm2.weight\n",
305
+ "up_blocks.1.attentions.1.transformer_blocks.2.norm3.bias\n",
306
+ "up_blocks.1.attentions.1.transformer_blocks.2.norm3.weight\n",
307
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_k.weight\n",
308
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.bias\n",
309
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.weight\n",
310
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_q.weight\n",
311
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_v.weight\n",
312
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_k.weight\n",
313
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.bias\n",
314
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.weight\n",
315
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_q.weight\n",
316
+ "up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_v.weight\n",
317
+ "up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.bias\n",
318
+ "up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.weight\n",
319
+ "up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.bias\n",
320
+ "up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.weight\n",
321
+ "up_blocks.1.attentions.2.transformer_blocks.2.norm1.bias\n",
322
+ "up_blocks.1.attentions.2.transformer_blocks.2.norm1.weight\n",
323
+ "up_blocks.1.attentions.2.transformer_blocks.2.norm2.bias\n",
324
+ "up_blocks.1.attentions.2.transformer_blocks.2.norm2.weight\n",
325
+ "up_blocks.1.attentions.2.transformer_blocks.2.norm3.bias\n",
326
+ "up_blocks.1.attentions.2.transformer_blocks.2.norm3.weight\n",
327
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight\n",
328
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n",
329
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n",
330
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight\n",
331
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight\n",
332
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight\n",
333
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n",
334
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n",
335
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight\n",
336
+ "up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight\n",
337
+ "up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n",
338
+ "up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n",
339
+ "up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias\n",
340
+ "up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight\n",
341
+ "up_blocks.2.attentions.0.transformer_blocks.2.norm1.bias\n",
342
+ "up_blocks.2.attentions.0.transformer_blocks.2.norm1.weight\n",
343
+ "up_blocks.2.attentions.0.transformer_blocks.2.norm2.bias\n",
344
+ "up_blocks.2.attentions.0.transformer_blocks.2.norm2.weight\n",
345
+ "up_blocks.2.attentions.0.transformer_blocks.2.norm3.bias\n",
346
+ "up_blocks.2.attentions.0.transformer_blocks.2.norm3.weight\n",
347
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight\n",
348
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n",
349
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n",
350
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight\n",
351
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight\n",
352
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight\n",
353
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n",
354
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n",
355
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight\n",
356
+ "up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight\n",
357
+ "up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n",
358
+ "up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n",
359
+ "up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias\n",
360
+ "up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight\n",
361
+ "up_blocks.2.attentions.1.transformer_blocks.2.norm1.bias\n",
362
+ "up_blocks.2.attentions.1.transformer_blocks.2.norm1.weight\n",
363
+ "up_blocks.2.attentions.1.transformer_blocks.2.norm2.bias\n",
364
+ "up_blocks.2.attentions.1.transformer_blocks.2.norm2.weight\n",
365
+ "up_blocks.2.attentions.1.transformer_blocks.2.norm3.bias\n",
366
+ "up_blocks.2.attentions.1.transformer_blocks.2.norm3.weight\n",
367
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_k.weight\n",
368
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.bias\n",
369
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.weight\n",
370
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_q.weight\n",
371
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_v.weight\n",
372
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_k.weight\n",
373
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.bias\n",
374
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.weight\n",
375
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_q.weight\n",
376
+ "up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_v.weight\n",
377
+ "up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.bias\n",
378
+ "up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.weight\n",
379
+ "up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.bias\n",
380
+ "up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.weight\n",
381
+ "up_blocks.2.attentions.2.transformer_blocks.2.norm1.bias\n",
382
+ "up_blocks.2.attentions.2.transformer_blocks.2.norm1.weight\n",
383
+ "up_blocks.2.attentions.2.transformer_blocks.2.norm2.bias\n",
384
+ "up_blocks.2.attentions.2.transformer_blocks.2.norm2.weight\n",
385
+ "up_blocks.2.attentions.2.transformer_blocks.2.norm3.bias\n",
386
+ "up_blocks.2.attentions.2.transformer_blocks.2.norm3.weight\n",
387
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_k.weight\n",
388
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_out.0.bias\n",
389
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_out.0.weight\n",
390
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_q.weight\n",
391
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn1.to_v.weight\n",
392
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_k.weight\n",
393
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_out.0.bias\n",
394
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_out.0.weight\n",
395
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_q.weight\n",
396
+ "up_blocks.3.attentions.0.transformer_blocks.2.attn2.to_v.weight\n",
397
+ "up_blocks.3.attentions.0.transformer_blocks.2.ff.net.0.proj.bias\n",
398
+ "up_blocks.3.attentions.0.transformer_blocks.2.ff.net.0.proj.weight\n",
399
+ "up_blocks.3.attentions.0.transformer_blocks.2.ff.net.2.bias\n",
400
+ "up_blocks.3.attentions.0.transformer_blocks.2.ff.net.2.weight\n",
401
+ "up_blocks.3.attentions.0.transformer_blocks.2.norm1.bias\n",
402
+ "up_blocks.3.attentions.0.transformer_blocks.2.norm1.weight\n",
403
+ "up_blocks.3.attentions.0.transformer_blocks.2.norm2.bias\n",
404
+ "up_blocks.3.attentions.0.transformer_blocks.2.norm2.weight\n",
405
+ "up_blocks.3.attentions.0.transformer_blocks.2.norm3.bias\n",
406
+ "up_blocks.3.attentions.0.transformer_blocks.2.norm3.weight\n",
407
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_k.weight\n",
408
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_out.0.bias\n",
409
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_out.0.weight\n",
410
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_q.weight\n",
411
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn1.to_v.weight\n",
412
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_k.weight\n",
413
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_out.0.bias\n",
414
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_out.0.weight\n",
415
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_q.weight\n",
416
+ "up_blocks.3.attentions.1.transformer_blocks.2.attn2.to_v.weight\n",
417
+ "up_blocks.3.attentions.1.transformer_blocks.2.ff.net.0.proj.bias\n",
418
+ "up_blocks.3.attentions.1.transformer_blocks.2.ff.net.0.proj.weight\n",
419
+ "up_blocks.3.attentions.1.transformer_blocks.2.ff.net.2.bias\n",
420
+ "up_blocks.3.attentions.1.transformer_blocks.2.ff.net.2.weight\n",
421
+ "up_blocks.3.attentions.1.transformer_blocks.2.norm1.bias\n",
422
+ "up_blocks.3.attentions.1.transformer_blocks.2.norm1.weight\n",
423
+ "up_blocks.3.attentions.1.transformer_blocks.2.norm2.bias\n",
424
+ "up_blocks.3.attentions.1.transformer_blocks.2.norm2.weight\n",
425
+ "up_blocks.3.attentions.1.transformer_blocks.2.norm3.bias\n",
426
+ "up_blocks.3.attentions.1.transformer_blocks.2.norm3.weight\n",
427
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_k.weight\n",
428
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_out.0.bias\n",
429
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_out.0.weight\n",
430
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_q.weight\n",
431
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn1.to_v.weight\n",
432
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_k.weight\n",
433
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_out.0.bias\n",
434
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_out.0.weight\n",
435
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_q.weight\n",
436
+ "up_blocks.3.attentions.2.transformer_blocks.2.attn2.to_v.weight\n",
437
+ "up_blocks.3.attentions.2.transformer_blocks.2.ff.net.0.proj.bias\n",
438
+ "up_blocks.3.attentions.2.transformer_blocks.2.ff.net.0.proj.weight\n",
439
+ "up_blocks.3.attentions.2.transformer_blocks.2.ff.net.2.bias\n",
440
+ "up_blocks.3.attentions.2.transformer_blocks.2.ff.net.2.weight\n",
441
+ "up_blocks.3.attentions.2.transformer_blocks.2.norm1.bias\n",
442
+ "up_blocks.3.attentions.2.transformer_blocks.2.norm1.weight\n",
443
+ "up_blocks.3.attentions.2.transformer_blocks.2.norm2.bias\n",
444
+ "up_blocks.3.attentions.2.transformer_blocks.2.norm2.weight\n",
445
+ "up_blocks.3.attentions.2.transformer_blocks.2.norm3.bias\n",
446
+ "up_blocks.3.attentions.2.transformer_blocks.2.norm3.weight\n",
447
+ "UNet2DConditionModel(\n",
448
+ " (conv_in): Conv2d(16, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
449
+ " (time_proj): Timesteps()\n",
450
+ " (time_embedding): TimestepEmbedding(\n",
451
+ " (linear_1): Linear(in_features=320, out_features=1280, bias=True)\n",
452
+ " (act): SiLU()\n",
453
+ " (linear_2): Linear(in_features=1280, out_features=1280, bias=True)\n",
454
+ " )\n",
455
+ " (down_blocks): ModuleList(\n",
456
+ " (0): CrossAttnDownBlock2D(\n",
457
+ " (attentions): ModuleList(\n",
458
+ " (0-1): 2 x Transformer2DModel(\n",
459
+ " (norm): GroupNorm(32, 320, eps=1e-06, affine=True)\n",
460
+ " (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))\n",
461
+ " (transformer_blocks): ModuleList(\n",
462
+ " (0-2): 3 x BasicTransformerBlock(\n",
463
+ " (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n",
464
+ " (attn1): Attention(\n",
465
+ " (to_q): Linear(in_features=320, out_features=320, bias=False)\n",
466
+ " (to_k): Linear(in_features=320, out_features=320, bias=False)\n",
467
+ " (to_v): Linear(in_features=320, out_features=320, bias=False)\n",
468
+ " (to_out): ModuleList(\n",
469
+ " (0): Linear(in_features=320, out_features=320, bias=True)\n",
470
+ " (1): Dropout(p=0.0, inplace=False)\n",
471
+ " )\n",
472
+ " )\n",
473
+ " (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n",
474
+ " (attn2): Attention(\n",
475
+ " (to_q): Linear(in_features=320, out_features=320, bias=False)\n",
476
+ " (to_k): Linear(in_features=768, out_features=320, bias=False)\n",
477
+ " (to_v): Linear(in_features=768, out_features=320, bias=False)\n",
478
+ " (to_out): ModuleList(\n",
479
+ " (0): Linear(in_features=320, out_features=320, bias=True)\n",
480
+ " (1): Dropout(p=0.0, inplace=False)\n",
481
+ " )\n",
482
+ " )\n",
483
+ " (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n",
484
+ " (ff): FeedForward(\n",
485
+ " (net): ModuleList(\n",
486
+ " (0): GEGLU(\n",
487
+ " (proj): Linear(in_features=320, out_features=2560, bias=True)\n",
488
+ " )\n",
489
+ " (1): Dropout(p=0.0, inplace=False)\n",
490
+ " (2): Linear(in_features=1280, out_features=320, bias=True)\n",
491
+ " )\n",
492
+ " )\n",
493
+ " )\n",
494
+ " )\n",
495
+ " (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))\n",
496
+ " )\n",
497
+ " )\n",
498
+ " (resnets): ModuleList(\n",
499
+ " (0-1): 2 x ResnetBlock2D(\n",
500
+ " (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)\n",
501
+ " (conv1): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
502
+ " (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)\n",
503
+ " (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)\n",
504
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
505
+ " (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
506
+ " (nonlinearity): SiLU()\n",
507
+ " )\n",
508
+ " )\n",
509
+ " (downsamplers): ModuleList(\n",
510
+ " (0): Downsample2D(\n",
511
+ " (conv): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
512
+ " )\n",
513
+ " )\n",
514
+ " )\n",
515
+ " (1): CrossAttnDownBlock2D(\n",
516
+ " (attentions): ModuleList(\n",
517
+ " (0-1): 2 x Transformer2DModel(\n",
518
+ " (norm): GroupNorm(32, 640, eps=1e-06, affine=True)\n",
519
+ " (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))\n",
520
+ " (transformer_blocks): ModuleList(\n",
521
+ " (0-2): 3 x BasicTransformerBlock(\n",
522
+ " (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n",
523
+ " (attn1): Attention(\n",
524
+ " (to_q): Linear(in_features=640, out_features=640, bias=False)\n",
525
+ " (to_k): Linear(in_features=640, out_features=640, bias=False)\n",
526
+ " (to_v): Linear(in_features=640, out_features=640, bias=False)\n",
527
+ " (to_out): ModuleList(\n",
528
+ " (0): Linear(in_features=640, out_features=640, bias=True)\n",
529
+ " (1): Dropout(p=0.0, inplace=False)\n",
530
+ " )\n",
531
+ " )\n",
532
+ " (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n",
533
+ " (attn2): Attention(\n",
534
+ " (to_q): Linear(in_features=640, out_features=640, bias=False)\n",
535
+ " (to_k): Linear(in_features=768, out_features=640, bias=False)\n",
536
+ " (to_v): Linear(in_features=768, out_features=640, bias=False)\n",
537
+ " (to_out): ModuleList(\n",
538
+ " (0): Linear(in_features=640, out_features=640, bias=True)\n",
539
+ " (1): Dropout(p=0.0, inplace=False)\n",
540
+ " )\n",
541
+ " )\n",
542
+ " (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n",
543
+ " (ff): FeedForward(\n",
544
+ " (net): ModuleList(\n",
545
+ " (0): GEGLU(\n",
546
+ " (proj): Linear(in_features=640, out_features=5120, bias=True)\n",
547
+ " )\n",
548
+ " (1): Dropout(p=0.0, inplace=False)\n",
549
+ " (2): Linear(in_features=2560, out_features=640, bias=True)\n",
550
+ " )\n",
551
+ " )\n",
552
+ " )\n",
553
+ " )\n",
554
+ " (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))\n",
555
+ " )\n",
556
+ " )\n",
557
+ " (resnets): ModuleList(\n",
558
+ " (0): ResnetBlock2D(\n",
559
+ " (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)\n",
560
+ " (conv1): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
561
+ " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n",
562
+ " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n",
563
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
564
+ " (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
565
+ " (nonlinearity): SiLU()\n",
566
+ " (conv_shortcut): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1))\n",
567
+ " )\n",
568
+ " (1): ResnetBlock2D(\n",
569
+ " (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)\n",
570
+ " (conv1): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
571
+ " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n",
572
+ " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n",
573
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
574
+ " (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
575
+ " (nonlinearity): SiLU()\n",
576
+ " )\n",
577
+ " )\n",
578
+ " (downsamplers): ModuleList(\n",
579
+ " (0): Downsample2D(\n",
580
+ " (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
581
+ " )\n",
582
+ " )\n",
583
+ " )\n",
584
+ " (2): CrossAttnDownBlock2D(\n",
585
+ " (attentions): ModuleList(\n",
586
+ " (0-1): 2 x Transformer2DModel(\n",
587
+ " (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)\n",
588
+ " (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
589
+ " (transformer_blocks): ModuleList(\n",
590
+ " (0-2): 3 x BasicTransformerBlock(\n",
591
+ " (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
592
+ " (attn1): Attention(\n",
593
+ " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n",
594
+ " (to_k): Linear(in_features=1280, out_features=1280, bias=False)\n",
595
+ " (to_v): Linear(in_features=1280, out_features=1280, bias=False)\n",
596
+ " (to_out): ModuleList(\n",
597
+ " (0): Linear(in_features=1280, out_features=1280, bias=True)\n",
598
+ " (1): Dropout(p=0.0, inplace=False)\n",
599
+ " )\n",
600
+ " )\n",
601
+ " (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
602
+ " (attn2): Attention(\n",
603
+ " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n",
604
+ " (to_k): Linear(in_features=768, out_features=1280, bias=False)\n",
605
+ " (to_v): Linear(in_features=768, out_features=1280, bias=False)\n",
606
+ " (to_out): ModuleList(\n",
607
+ " (0): Linear(in_features=1280, out_features=1280, bias=True)\n",
608
+ " (1): Dropout(p=0.0, inplace=False)\n",
609
+ " )\n",
610
+ " )\n",
611
+ " (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
612
+ " (ff): FeedForward(\n",
613
+ " (net): ModuleList(\n",
614
+ " (0): GEGLU(\n",
615
+ " (proj): Linear(in_features=1280, out_features=10240, bias=True)\n",
616
+ " )\n",
617
+ " (1): Dropout(p=0.0, inplace=False)\n",
618
+ " (2): Linear(in_features=5120, out_features=1280, bias=True)\n",
619
+ " )\n",
620
+ " )\n",
621
+ " )\n",
622
+ " )\n",
623
+ " (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
624
+ " )\n",
625
+ " )\n",
626
+ " (resnets): ModuleList(\n",
627
+ " (0): ResnetBlock2D(\n",
628
+ " (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)\n",
629
+ " (conv1): Conv2d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
630
+ " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
631
+ " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
632
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
633
+ " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
634
+ " (nonlinearity): SiLU()\n",
635
+ " (conv_shortcut): Conv2d(640, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
636
+ " )\n",
637
+ " (1): ResnetBlock2D(\n",
638
+ " (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
639
+ " (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
640
+ " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
641
+ " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
642
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
643
+ " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
644
+ " (nonlinearity): SiLU()\n",
645
+ " )\n",
646
+ " )\n",
647
+ " (downsamplers): ModuleList(\n",
648
+ " (0): Downsample2D(\n",
649
+ " (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
650
+ " )\n",
651
+ " )\n",
652
+ " )\n",
653
+ " (3): DownBlock2D(\n",
654
+ " (resnets): ModuleList(\n",
655
+ " (0-1): 2 x ResnetBlock2D(\n",
656
+ " (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
657
+ " (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
658
+ " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
659
+ " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
660
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
661
+ " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
662
+ " (nonlinearity): SiLU()\n",
663
+ " )\n",
664
+ " )\n",
665
+ " )\n",
666
+ " )\n",
667
+ " (up_blocks): ModuleList(\n",
668
+ " (0): UpBlock2D(\n",
669
+ " (resnets): ModuleList(\n",
670
+ " (0-2): 3 x ResnetBlock2D(\n",
671
+ " (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)\n",
672
+ " (conv1): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
673
+ " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
674
+ " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
675
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
676
+ " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
677
+ " (nonlinearity): SiLU()\n",
678
+ " (conv_shortcut): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
679
+ " )\n",
680
+ " )\n",
681
+ " (upsamplers): ModuleList(\n",
682
+ " (0): Upsample2D(\n",
683
+ " (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
684
+ " )\n",
685
+ " )\n",
686
+ " )\n",
687
+ " (1): CrossAttnUpBlock2D(\n",
688
+ " (attentions): ModuleList(\n",
689
+ " (0-2): 3 x Transformer2DModel(\n",
690
+ " (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)\n",
691
+ " (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
692
+ " (transformer_blocks): ModuleList(\n",
693
+ " (0-2): 3 x BasicTransformerBlock(\n",
694
+ " (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
695
+ " (attn1): Attention(\n",
696
+ " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n",
697
+ " (to_k): Linear(in_features=1280, out_features=1280, bias=False)\n",
698
+ " (to_v): Linear(in_features=1280, out_features=1280, bias=False)\n",
699
+ " (to_out): ModuleList(\n",
700
+ " (0): Linear(in_features=1280, out_features=1280, bias=True)\n",
701
+ " (1): Dropout(p=0.0, inplace=False)\n",
702
+ " )\n",
703
+ " )\n",
704
+ " (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
705
+ " (attn2): Attention(\n",
706
+ " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n",
707
+ " (to_k): Linear(in_features=768, out_features=1280, bias=False)\n",
708
+ " (to_v): Linear(in_features=768, out_features=1280, bias=False)\n",
709
+ " (to_out): ModuleList(\n",
710
+ " (0): Linear(in_features=1280, out_features=1280, bias=True)\n",
711
+ " (1): Dropout(p=0.0, inplace=False)\n",
712
+ " )\n",
713
+ " )\n",
714
+ " (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
715
+ " (ff): FeedForward(\n",
716
+ " (net): ModuleList(\n",
717
+ " (0): GEGLU(\n",
718
+ " (proj): Linear(in_features=1280, out_features=10240, bias=True)\n",
719
+ " )\n",
720
+ " (1): Dropout(p=0.0, inplace=False)\n",
721
+ " (2): Linear(in_features=5120, out_features=1280, bias=True)\n",
722
+ " )\n",
723
+ " )\n",
724
+ " )\n",
725
+ " )\n",
726
+ " (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
727
+ " )\n",
728
+ " )\n",
729
+ " (resnets): ModuleList(\n",
730
+ " (0-1): 2 x ResnetBlock2D(\n",
731
+ " (norm1): GroupNorm(32, 2560, eps=1e-05, affine=True)\n",
732
+ " (conv1): Conv2d(2560, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
733
+ " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
734
+ " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
735
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
736
+ " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
737
+ " (nonlinearity): SiLU()\n",
738
+ " (conv_shortcut): Conv2d(2560, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
739
+ " )\n",
740
+ " (2): ResnetBlock2D(\n",
741
+ " (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)\n",
742
+ " (conv1): Conv2d(1920, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
743
+ " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
744
+ " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
745
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
746
+ " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
747
+ " (nonlinearity): SiLU()\n",
748
+ " (conv_shortcut): Conv2d(1920, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
749
+ " )\n",
750
+ " )\n",
751
+ " (upsamplers): ModuleList(\n",
752
+ " (0): Upsample2D(\n",
753
+ " (conv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
754
+ " )\n",
755
+ " )\n",
756
+ " )\n",
757
+ " (2): CrossAttnUpBlock2D(\n",
758
+ " (attentions): ModuleList(\n",
759
+ " (0-2): 3 x Transformer2DModel(\n",
760
+ " (norm): GroupNorm(32, 640, eps=1e-06, affine=True)\n",
761
+ " (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))\n",
762
+ " (transformer_blocks): ModuleList(\n",
763
+ " (0-2): 3 x BasicTransformerBlock(\n",
764
+ " (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n",
765
+ " (attn1): Attention(\n",
766
+ " (to_q): Linear(in_features=640, out_features=640, bias=False)\n",
767
+ " (to_k): Linear(in_features=640, out_features=640, bias=False)\n",
768
+ " (to_v): Linear(in_features=640, out_features=640, bias=False)\n",
769
+ " (to_out): ModuleList(\n",
770
+ " (0): Linear(in_features=640, out_features=640, bias=True)\n",
771
+ " (1): Dropout(p=0.0, inplace=False)\n",
772
+ " )\n",
773
+ " )\n",
774
+ " (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n",
775
+ " (attn2): Attention(\n",
776
+ " (to_q): Linear(in_features=640, out_features=640, bias=False)\n",
777
+ " (to_k): Linear(in_features=768, out_features=640, bias=False)\n",
778
+ " (to_v): Linear(in_features=768, out_features=640, bias=False)\n",
779
+ " (to_out): ModuleList(\n",
780
+ " (0): Linear(in_features=640, out_features=640, bias=True)\n",
781
+ " (1): Dropout(p=0.0, inplace=False)\n",
782
+ " )\n",
783
+ " )\n",
784
+ " (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)\n",
785
+ " (ff): FeedForward(\n",
786
+ " (net): ModuleList(\n",
787
+ " (0): GEGLU(\n",
788
+ " (proj): Linear(in_features=640, out_features=5120, bias=True)\n",
789
+ " )\n",
790
+ " (1): Dropout(p=0.0, inplace=False)\n",
791
+ " (2): Linear(in_features=2560, out_features=640, bias=True)\n",
792
+ " )\n",
793
+ " )\n",
794
+ " )\n",
795
+ " )\n",
796
+ " (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))\n",
797
+ " )\n",
798
+ " )\n",
799
+ " (resnets): ModuleList(\n",
800
+ " (0): ResnetBlock2D(\n",
801
+ " (norm1): GroupNorm(32, 1920, eps=1e-05, affine=True)\n",
802
+ " (conv1): Conv2d(1920, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
803
+ " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n",
804
+ " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n",
805
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
806
+ " (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
807
+ " (nonlinearity): SiLU()\n",
808
+ " (conv_shortcut): Conv2d(1920, 640, kernel_size=(1, 1), stride=(1, 1))\n",
809
+ " )\n",
810
+ " (1): ResnetBlock2D(\n",
811
+ " (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
812
+ " (conv1): Conv2d(1280, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
813
+ " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n",
814
+ " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n",
815
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
816
+ " (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
817
+ " (nonlinearity): SiLU()\n",
818
+ " (conv_shortcut): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))\n",
819
+ " )\n",
820
+ " (2): ResnetBlock2D(\n",
821
+ " (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)\n",
822
+ " (conv1): Conv2d(960, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
823
+ " (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)\n",
824
+ " (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)\n",
825
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
826
+ " (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
827
+ " (nonlinearity): SiLU()\n",
828
+ " (conv_shortcut): Conv2d(960, 640, kernel_size=(1, 1), stride=(1, 1))\n",
829
+ " )\n",
830
+ " )\n",
831
+ " (upsamplers): ModuleList(\n",
832
+ " (0): Upsample2D(\n",
833
+ " (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
834
+ " )\n",
835
+ " )\n",
836
+ " )\n",
837
+ " (3): CrossAttnUpBlock2D(\n",
838
+ " (attentions): ModuleList(\n",
839
+ " (0-2): 3 x Transformer2DModel(\n",
840
+ " (norm): GroupNorm(32, 320, eps=1e-06, affine=True)\n",
841
+ " (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))\n",
842
+ " (transformer_blocks): ModuleList(\n",
843
+ " (0-2): 3 x BasicTransformerBlock(\n",
844
+ " (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n",
845
+ " (attn1): Attention(\n",
846
+ " (to_q): Linear(in_features=320, out_features=320, bias=False)\n",
847
+ " (to_k): Linear(in_features=320, out_features=320, bias=False)\n",
848
+ " (to_v): Linear(in_features=320, out_features=320, bias=False)\n",
849
+ " (to_out): ModuleList(\n",
850
+ " (0): Linear(in_features=320, out_features=320, bias=True)\n",
851
+ " (1): Dropout(p=0.0, inplace=False)\n",
852
+ " )\n",
853
+ " )\n",
854
+ " (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n",
855
+ " (attn2): Attention(\n",
856
+ " (to_q): Linear(in_features=320, out_features=320, bias=False)\n",
857
+ " (to_k): Linear(in_features=768, out_features=320, bias=False)\n",
858
+ " (to_v): Linear(in_features=768, out_features=320, bias=False)\n",
859
+ " (to_out): ModuleList(\n",
860
+ " (0): Linear(in_features=320, out_features=320, bias=True)\n",
861
+ " (1): Dropout(p=0.0, inplace=False)\n",
862
+ " )\n",
863
+ " )\n",
864
+ " (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n",
865
+ " (ff): FeedForward(\n",
866
+ " (net): ModuleList(\n",
867
+ " (0): GEGLU(\n",
868
+ " (proj): Linear(in_features=320, out_features=2560, bias=True)\n",
869
+ " )\n",
870
+ " (1): Dropout(p=0.0, inplace=False)\n",
871
+ " (2): Linear(in_features=1280, out_features=320, bias=True)\n",
872
+ " )\n",
873
+ " )\n",
874
+ " )\n",
875
+ " )\n",
876
+ " (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))\n",
877
+ " )\n",
878
+ " )\n",
879
+ " (resnets): ModuleList(\n",
880
+ " (0): ResnetBlock2D(\n",
881
+ " (norm1): GroupNorm(32, 960, eps=1e-05, affine=True)\n",
882
+ " (conv1): Conv2d(960, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
883
+ " (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)\n",
884
+ " (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)\n",
885
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
886
+ " (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
887
+ " (nonlinearity): SiLU()\n",
888
+ " (conv_shortcut): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1))\n",
889
+ " )\n",
890
+ " (1-2): 2 x ResnetBlock2D(\n",
891
+ " (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)\n",
892
+ " (conv1): Conv2d(640, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
893
+ " (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)\n",
894
+ " (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)\n",
895
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
896
+ " (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
897
+ " (nonlinearity): SiLU()\n",
898
+ " (conv_shortcut): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1))\n",
899
+ " )\n",
900
+ " )\n",
901
+ " )\n",
902
+ " )\n",
903
+ " (mid_block): UNetMidBlock2DCrossAttn(\n",
904
+ " (attentions): ModuleList(\n",
905
+ " (0): Transformer2DModel(\n",
906
+ " (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)\n",
907
+ " (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
908
+ " (transformer_blocks): ModuleList(\n",
909
+ " (0-2): 3 x BasicTransformerBlock(\n",
910
+ " (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
911
+ " (attn1): Attention(\n",
912
+ " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n",
913
+ " (to_k): Linear(in_features=1280, out_features=1280, bias=False)\n",
914
+ " (to_v): Linear(in_features=1280, out_features=1280, bias=False)\n",
915
+ " (to_out): ModuleList(\n",
916
+ " (0): Linear(in_features=1280, out_features=1280, bias=True)\n",
917
+ " (1): Dropout(p=0.0, inplace=False)\n",
918
+ " )\n",
919
+ " )\n",
920
+ " (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
921
+ " (attn2): Attention(\n",
922
+ " (to_q): Linear(in_features=1280, out_features=1280, bias=False)\n",
923
+ " (to_k): Linear(in_features=768, out_features=1280, bias=False)\n",
924
+ " (to_v): Linear(in_features=768, out_features=1280, bias=False)\n",
925
+ " (to_out): ModuleList(\n",
926
+ " (0): Linear(in_features=1280, out_features=1280, bias=True)\n",
927
+ " (1): Dropout(p=0.0, inplace=False)\n",
928
+ " )\n",
929
+ " )\n",
930
+ " (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
931
+ " (ff): FeedForward(\n",
932
+ " (net): ModuleList(\n",
933
+ " (0): GEGLU(\n",
934
+ " (proj): Linear(in_features=1280, out_features=10240, bias=True)\n",
935
+ " )\n",
936
+ " (1): Dropout(p=0.0, inplace=False)\n",
937
+ " (2): Linear(in_features=5120, out_features=1280, bias=True)\n",
938
+ " )\n",
939
+ " )\n",
940
+ " )\n",
941
+ " )\n",
942
+ " (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))\n",
943
+ " )\n",
944
+ " )\n",
945
+ " (resnets): ModuleList(\n",
946
+ " (0-1): 2 x ResnetBlock2D(\n",
947
+ " (norm1): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
948
+ " (conv1): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
949
+ " (time_emb_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
950
+ " (norm2): GroupNorm(32, 1280, eps=1e-05, affine=True)\n",
951
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
952
+ " (conv2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
953
+ " (nonlinearity): SiLU()\n",
954
+ " )\n",
955
+ " )\n",
956
+ " )\n",
957
+ " (conv_norm_out): GroupNorm(32, 320, eps=1e-05, affine=True)\n",
958
+ " (conv_act): SiLU()\n",
959
+ " (conv_out): Conv2d(320, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
960
+ ")\n"
961
+ ]
962
+ }
963
+ ],
964
+ "source": [
965
+ "import torch\n",
966
+ "from diffusers import UNet2DConditionModel\n",
967
+ "from tqdm import tqdm\n",
968
+ "\n",
969
+ "def log(message):\n",
970
+ " print(message)\n",
971
+ "\n",
972
+ "def main():\n",
973
+ " checkpoint_path_old = \"unet\"\n",
974
+ " checkpoint_path_new = \"sd15_tmp\"\n",
975
+ " device = \"cuda\"\n",
976
+ " dtype = torch.float16\n",
977
+ "\n",
978
+ " # Загрузка моделей\n",
979
+ " old_unet = UNet2DConditionModel.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
980
+ " new_unet = UNet2DConditionModel.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
981
+ "\n",
982
+ " old_state_dict = old_unet.state_dict()\n",
983
+ " new_state_dict = new_unet.state_dict()\n",
984
+ "\n",
985
+ " transferred_state_dict = {}\n",
986
+ " transfer_stats = {\n",
987
+ " \"перенесено\": 0,\n",
988
+ " \"несовпадение_размеров\": 0,\n",
989
+ " \"пропущено\": 0\n",
990
+ " }\n",
991
+ "\n",
992
+ " transferred_keys = set()\n",
993
+ "\n",
994
+ " # Обрабатываем каждый ключ старой модели\n",
995
+ " for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n",
996
+ " new_key = old_key\n",
997
+ "\n",
998
+ " # Проверяем, существует ли ключ в новой модели\n",
999
+ " if new_key in new_state_dict:\n",
1000
+ " # Проверяем совместимость размеров\n",
1001
+ " if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
1002
+ " transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
1003
+ " transferred_keys.add(new_key)\n",
1004
+ " transfer_stats[\"перенесено\"] += 1\n",
1005
+ " #log(f\"✓ Перенос: {old_key} -> {new_key}, форма: {old_state_dict[old_key].shape}\")\n",
1006
+ " else:\n",
1007
+ " log(f\"✗ Несовпадение размеров: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})\")\n",
1008
+ " transfer_stats[\"несовпадение_размеров\"] += 1\n",
1009
+ " else:\n",
1010
+ " log(f\"? Ключ не найден в новой модели: {old_key} -> {old_state_dict[old_key].shape}\")\n",
1011
+ " transfer_stats[\"пропущено\"] += 1\n",
1012
+ "\n",
1013
+ " # Обновляем состояние новой модели перенесенными весами\n",
1014
+ " new_state_dict.update(transferred_state_dict)\n",
1015
+ " new_unet.load_state_dict(new_state_dict)\n",
1016
+ " new_unet.save_pretrained(\"unet_1.3b\")\n",
1017
+ "\n",
1018
+ " # Получаем список неперенесенных ключей\n",
1019
+ " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
1020
+ "\n",
1021
+ " print(\"Статистика переноса:\", transfer_stats)\n",
1022
+ " print(\"Неперенесенные ключи в новой модели:\")\n",
1023
+ " for key in non_transferred_keys:\n",
1024
+ " print(key)\n",
1025
+ "\n",
1026
+ " print(new_unet)\n",
1027
+ "\n",
1028
+ "if __name__ == \"__main__\":\n",
1029
+ " main()\n",
1030
+ "# Статистика переноса: {'перенесено': 686, 'несовпадение_размеров': 0, 'пропущено': 0}"
1031
+ ]
1032
+ },
1033
+ {
1034
+ "cell_type": "code",
1035
+ "execution_count": null,
1036
+ "id": "f2438e3d-4b83-4b3f-8e78-53cbcc35f6e4",
1037
+ "metadata": {},
1038
+ "outputs": [],
1039
+ "source": []
1040
+ }
1041
+ ],
1042
+ "metadata": {
1043
+ "kernelspec": {
1044
+ "display_name": "Python 3 (ipykernel)",
1045
+ "language": "python",
1046
+ "name": "python3"
1047
+ },
1048
+ "language_info": {
1049
+ "codemirror_mode": {
1050
+ "name": "ipython",
1051
+ "version": 3
1052
+ },
1053
+ "file_extension": ".py",
1054
+ "mimetype": "text/x-python",
1055
+ "name": "python",
1056
+ "nbconvert_exporter": "python",
1057
+ "pygments_lexer": "ipython3",
1058
+ "version": "3.12.3"
1059
+ }
1060
+ },
1061
+ "nbformat": 4,
1062
+ "nbformat_minor": 5
1063
+ }