babkasotona committed on
Commit 445fb66 · verified · 1 Parent(s): b4f2b1a

Upload folder using huggingface_hub

.ipynb_checkpoints/config-checkpoint.json CHANGED
@@ -1,7 +1,7 @@
 {
   "_class_name": "AutoencoderKL",
   "_diffusers_version": "0.37.0",
-  "_name_or_path": "vae16x32ch",
+  "_name_or_path": "vae16x32ch_empty",
   "act_fn": "silu",
   "block_out_channels": [
     128,
@@ -14,7 +14,6 @@
     "DownEncoderBlock2D",
     "DownEncoderBlock2D",
     "DownEncoderBlock2D",
-    "DownEncoderBlock2D",
     "DownEncoderBlock2D"
   ],
   "force_upcast": false,
@@ -99,7 +98,6 @@
     "UpDecoderBlock2D",
     "UpDecoderBlock2D",
     "UpDecoderBlock2D",
-    "UpDecoderBlock2D",
     "UpDecoderBlock2D"
   ],
   "use_post_quant_conv": true,
.ipynb_checkpoints/create_symmetric-Copy1-checkpoint.ipynb ADDED
@@ -0,0 +1,765 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "407171be-ab46-442b-a0bd-83ca75173eba",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "AutoencoderKL(\n",
+       "  (encoder): Encoder(\n",
+       "    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "    (down_blocks): ModuleList(\n",
+       "      (0-1): 2 x DownEncoderBlock2D(\n",
+       "        (resnets): ModuleList(\n",
+       "          (0-2): 3 x ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "          )\n",
+       "        )\n",
+       "        (downsamplers): ModuleList(\n",
+       "          (0): Downsample2D(\n",
+       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (2): DownEncoderBlock2D(\n",
+       "        (resnets): ModuleList(\n",
+       "          (0): ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "            (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
+       "          )\n",
+       "          (1-2): 2 x ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "          )\n",
+       "        )\n",
+       "        (downsamplers): ModuleList(\n",
+       "          (0): Downsample2D(\n",
+       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (3): DownEncoderBlock2D(\n",
+       "        (resnets): ModuleList(\n",
+       "          (0): ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "            (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
+       "          )\n",
+       "          (1-2): 2 x ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "          )\n",
+       "        )\n",
+       "        (downsamplers): ModuleList(\n",
+       "          (0): Downsample2D(\n",
+       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (4): DownEncoderBlock2D(\n",
+       "        (resnets): ModuleList(\n",
+       "          (0-2): 3 x ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (mid_block): UNetMidBlock2D(\n",
+       "      (attentions): ModuleList(\n",
+       "        (0): Attention(\n",
+       "          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "          (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
+       "          (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
+       "          (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
+       "          (to_out): ModuleList(\n",
+       "            (0): Linear(in_features=512, out_features=512, bias=True)\n",
+       "            (1): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (resnets): ModuleList(\n",
+       "        (0-1): 2 x ResnetBlock2D(\n",
+       "          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          (nonlinearity): SiLU()\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "    (conv_act): SiLU()\n",
+       "    (conv_out): Conv2d(512, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "  )\n",
+       "  (decoder): Decoder(\n",
+       "    (conv_in): Conv2d(32, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "    (up_blocks): ModuleList(\n",
+       "      (0-1): 2 x UpDecoderBlock2D(\n",
+       "        (resnets): ModuleList(\n",
+       "          (0-3): 4 x ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "          )\n",
+       "        )\n",
+       "        (upsamplers): ModuleList(\n",
+       "          (0): Upsample2D(\n",
+       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (2): UpDecoderBlock2D(\n",
+       "        (resnets): ModuleList(\n",
+       "          (0): ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "            (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
+       "          )\n",
+       "          (1-3): 3 x ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "          )\n",
+       "        )\n",
+       "        (upsamplers): ModuleList(\n",
+       "          (0): Upsample2D(\n",
+       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (3): UpDecoderBlock2D(\n",
+       "        (resnets): ModuleList(\n",
+       "          (0): ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "            (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
+       "          )\n",
+       "          (1-3): 3 x ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "          )\n",
+       "        )\n",
+       "        (upsamplers): ModuleList(\n",
+       "          (0): Upsample2D(\n",
+       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (4): UpDecoderBlock2D(\n",
+       "        (resnets): ModuleList(\n",
+       "          (0-3): 4 x ResnetBlock2D(\n",
+       "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+       "            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+       "            (dropout): Dropout(p=0.0, inplace=False)\n",
+       "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "            (nonlinearity): SiLU()\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (mid_block): UNetMidBlock2D(\n",
+       "      (attentions): ModuleList(\n",
+       "        (0): Attention(\n",
+       "          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "          (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
+       "          (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
+       "          (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
+       "          (to_out): ModuleList(\n",
+       "            (0): Linear(in_features=512, out_features=512, bias=True)\n",
+       "            (1): Dropout(p=0.0, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (resnets): ModuleList(\n",
+       "        (0-1): 2 x ResnetBlock2D(\n",
+       "          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+       "          (dropout): Dropout(p=0.0, inplace=False)\n",
+       "          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "          (nonlinearity): SiLU()\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+       "    (conv_act): SiLU()\n",
+       "    (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+       "  )\n",
+       "  (quant_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
+       "  (post_quant_conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
+       ")\n"
+      ]
+     }
+    ],
+    "source": [
+     "from diffusers.models import AutoencoderKL\n",
+     "import torch\n",
+     "\n",
+     "config = {\n",
+     "    \"_class_name\": \"AutoencoderKL\",\n",
+     "    \"_diffusers_version\": \"0.36.0\",\n",
+     "    \"act_fn\": \"silu\",\n",
+     "    \"block_out_channels\": [\n",
+     "        128,\n",
+     "        128,\n",
+     "        256,\n",
+     "        512,\n",
+     "        512\n",
+     "    ],\n",
+     "    \"down_block_types\": [\n",
+     "        \"DownEncoderBlock2D\",\n",
+     "        \"DownEncoderBlock2D\",\n",
+     "        \"DownEncoderBlock2D\",\n",
+     "        \"DownEncoderBlock2D\",\n",
+     "        \"DownEncoderBlock2D\"\n",
+     "    ],\n",
+     "    \"force_upcast\": False,\n",
+     "    \"in_channels\": 3,\n",
+     "    \"latent_channels\": 32,\n",
+     "    \"latents_mean\": [\n",
+     "        -0.03542253375053406,\n",
+     "        0.20086465775966644,\n",
+     "        -0.016413161531090736,\n",
+     "        -0.0956302210688591,\n",
+     "        -0.2672063112258911,\n",
+     "        0.2609933018684387,\n",
+     "        -0.07806991040706635,\n",
+     "        -0.48407721519470215,\n",
+     "        0.21844269335269928,\n",
+     "        -0.1122383326292038,\n",
+     "        0.27197545766830444,\n",
+     "        -0.18958772718906403,\n",
+     "        0.18776826560497284,\n",
+     "        0.0987580344080925,\n",
+     "        0.2837068736553192,\n",
+     "        -0.4486690163612366,\n",
+     "        0.4816776514053345,\n",
+     "        0.02947971224784851,\n",
+     "        -0.1337375044822693,\n",
+     "        -0.39750921726226807,\n",
+     "        -0.08513020724058151,\n",
+     "        -0.054023586213588715,\n",
+     "        -0.3943594992160797,\n",
+     "        0.23918119072914124,\n",
+     "        -0.12466679513454437,\n",
+     "        0.09935147315263748,\n",
+     "        0.31858691573143005,\n",
+     "        0.48585832118988037,\n",
+     "        -0.6416525840759277,\n",
+     "        -0.15164820849895477,\n",
+     "        -0.4693508744239807,\n",
+     "        -0.13071806728839874\n",
+     "    ],\n",
+     "    \"latents_std\": [\n",
+     "        1.5792087316513062,\n",
+     "        1.5769503116607666,\n",
+     "        1.5864241123199463,\n",
+     "        1.6454921960830688,\n",
+     "        1.5336694717407227,\n",
+     "        1.5587652921676636,\n",
+     "        1.5838669538497925,\n",
+     "        1.5659377574920654,\n",
+     "        1.6860467195510864,\n",
+     "        1.5192310810089111,\n",
+     "        1.573639988899231,\n",
+     "        1.5953549146652222,\n",
+     "        1.5271092653274536,\n",
+     "        1.6246271133422852,\n",
+     "        1.7054023742675781,\n",
+     "        1.607722282409668,\n",
+     "        1.558642864227295,\n",
+     "        1.5824549198150635,\n",
+     "        1.6202995777130127,\n",
+     "        1.6206320524215698,\n",
+     "        1.6379750967025757,\n",
+     "        1.6527063846588135,\n",
+     "        1.498811960220337,\n",
+     "        1.5706247091293335,\n",
+     "        1.5854856967926025,\n",
+     "        1.4828169345855713,\n",
+     "        1.5693111419677734,\n",
+     "        1.692481517791748,\n",
+     "        1.6409776210784912,\n",
+     "        1.6216280460357666,\n",
+     "        1.6087706089019775,\n",
+     "        1.5776633024215698\n",
+     "    ],\n",
+     "    \"layers_per_block\": 2,\n",
+     "    \"mid_block_add_attention\": True,\n",
+     "    \"norm_num_groups\": 32,\n",
+     "    \"out_channels\": 3,\n",
+     "    \"sample_size\": 32,\n",
+     "    \"scaling_factor\": 1.0,\n",
+     "    \"shift_factor\": 0.0,\n",
+     "    \"up_block_types\": [\n",
+     "        \"UpDecoderBlock2D\",\n",
+     "        \"UpDecoderBlock2D\",\n",
+     "        \"UpDecoderBlock2D\",\n",
+     "        \"UpDecoderBlock2D\",\n",
+     "        \"UpDecoderBlock2D\"\n",
+     "    ],\n",
+     "    \"use_post_quant_conv\": True,\n",
+     "    \"use_quant_conv\": True\n",
+     "}\n",
+     "\n",
+     "\n",
+     "vae = AutoencoderKL(\n",
+     "    act_fn=config[\"act_fn\"],\n",
+     "    block_out_channels=config[\"block_out_channels\"],\n",
+     "    down_block_types=config[\"down_block_types\"],\n",
+     "    up_block_types=config[\"up_block_types\"],\n",
+     "    in_channels=config[\"in_channels\"],\n",
+     "    out_channels=config[\"out_channels\"],\n",
+     "    latent_channels=config[\"latent_channels\"],\n",
+     "    layers_per_block=3,  # deliberately overrides config[\"layers_per_block\"] == 2\n",
+     "    norm_num_groups=config[\"norm_num_groups\"],\n",
+     "    sample_size=config[\"sample_size\"],\n",
+     "    scaling_factor=config[\"scaling_factor\"],\n",
+     "    force_upcast=config[\"force_upcast\"],\n",
+     "    mid_block_add_attention=config[\"mid_block_add_attention\"],\n",
+     "    use_quant_conv=config[\"use_quant_conv\"],\n",
+     "    use_post_quant_conv=config[\"use_post_quant_conv\"],\n",
+     "    latents_mean=config[\"latents_mean\"],\n",
+     "    latents_std=config[\"latents_std\"],\n",
+     ")\n",
+     "\n",
+     "vae.save_pretrained(\"vae16x32ch_empty\")\n",
+     "\n",
+     "print(vae)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "a2950158-5203-42b9-8791-e231ddbf1063",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Transferring weights: 100%|██████████| 292/292 [00:00<00:00, 35760.83it/s]\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Transfer stats: {'transferred': 292, 'size_mismatch': 0, 'skipped': 0}\n",
+       "Keys not transferred in the new model:\n",
+       "decoder.up_blocks.0.resnets.3.conv1.bias\n",
+       "decoder.up_blocks.0.resnets.3.conv1.weight\n",
+       "decoder.up_blocks.0.resnets.3.conv2.bias\n",
+       "decoder.up_blocks.0.resnets.3.conv2.weight\n",
+       "decoder.up_blocks.0.resnets.3.norm1.bias\n",
+       "decoder.up_blocks.0.resnets.3.norm1.weight\n",
+       "decoder.up_blocks.0.resnets.3.norm2.bias\n",
+       "decoder.up_blocks.0.resnets.3.norm2.weight\n",
+       "decoder.up_blocks.1.resnets.3.conv1.bias\n",
+       "decoder.up_blocks.1.resnets.3.conv1.weight\n",
+       "decoder.up_blocks.1.resnets.3.conv2.bias\n",
+       "decoder.up_blocks.1.resnets.3.conv2.weight\n",
+       "decoder.up_blocks.1.resnets.3.norm1.bias\n",
+       "decoder.up_blocks.1.resnets.3.norm1.weight\n",
+       "decoder.up_blocks.1.resnets.3.norm2.bias\n",
+       "decoder.up_blocks.1.resnets.3.norm2.weight\n",
+       "decoder.up_blocks.2.resnets.3.conv1.bias\n",
+       "decoder.up_blocks.2.resnets.3.conv1.weight\n",
+       "decoder.up_blocks.2.resnets.3.conv2.bias\n",
+       "decoder.up_blocks.2.resnets.3.conv2.weight\n",
+       "decoder.up_blocks.2.resnets.3.norm1.bias\n",
+       "decoder.up_blocks.2.resnets.3.norm1.weight\n",
+       "decoder.up_blocks.2.resnets.3.norm2.bias\n",
+       "decoder.up_blocks.2.resnets.3.norm2.weight\n",
+       "decoder.up_blocks.3.resnets.3.conv1.bias\n",
+       "decoder.up_blocks.3.resnets.3.conv1.weight\n",
+       "decoder.up_blocks.3.resnets.3.conv2.bias\n",
+       "decoder.up_blocks.3.resnets.3.conv2.weight\n",
+       "decoder.up_blocks.3.resnets.3.norm1.bias\n",
+       "decoder.up_blocks.3.resnets.3.norm1.weight\n",
+       "decoder.up_blocks.3.resnets.3.norm2.bias\n",
+       "decoder.up_blocks.3.resnets.3.norm2.weight\n",
+       "decoder.up_blocks.4.resnets.3.conv1.bias\n",
+       "decoder.up_blocks.4.resnets.3.conv1.weight\n",
+       "decoder.up_blocks.4.resnets.3.conv2.bias\n",
+       "decoder.up_blocks.4.resnets.3.conv2.weight\n",
+       "decoder.up_blocks.4.resnets.3.norm1.bias\n",
+       "decoder.up_blocks.4.resnets.3.norm1.weight\n",
+       "decoder.up_blocks.4.resnets.3.norm2.bias\n",
+       "decoder.up_blocks.4.resnets.3.norm2.weight\n",
+       "encoder.down_blocks.0.resnets.2.conv1.bias\n",
+       "encoder.down_blocks.0.resnets.2.conv1.weight\n",
+       "encoder.down_blocks.0.resnets.2.conv2.bias\n",
+       "encoder.down_blocks.0.resnets.2.conv2.weight\n",
+       "encoder.down_blocks.0.resnets.2.norm1.bias\n",
+       "encoder.down_blocks.0.resnets.2.norm1.weight\n",
+       "encoder.down_blocks.0.resnets.2.norm2.bias\n",
+       "encoder.down_blocks.0.resnets.2.norm2.weight\n",
+       "encoder.down_blocks.1.resnets.2.conv1.bias\n",
+       "encoder.down_blocks.1.resnets.2.conv1.weight\n",
+       "encoder.down_blocks.1.resnets.2.conv2.bias\n",
+       "encoder.down_blocks.1.resnets.2.conv2.weight\n",
+       "encoder.down_blocks.1.resnets.2.norm1.bias\n",
+       "encoder.down_blocks.1.resnets.2.norm1.weight\n",
+       "encoder.down_blocks.1.resnets.2.norm2.bias\n",
+       "encoder.down_blocks.1.resnets.2.norm2.weight\n",
+       "encoder.down_blocks.2.resnets.2.conv1.bias\n",
+       "encoder.down_blocks.2.resnets.2.conv1.weight\n",
+       "encoder.down_blocks.2.resnets.2.conv2.bias\n",
+       "encoder.down_blocks.2.resnets.2.conv2.weight\n",
+       "encoder.down_blocks.2.resnets.2.norm1.bias\n",
+       "encoder.down_blocks.2.resnets.2.norm1.weight\n",
+       "encoder.down_blocks.2.resnets.2.norm2.bias\n",
+       "encoder.down_blocks.2.resnets.2.norm2.weight\n",
+       "encoder.down_blocks.3.resnets.2.conv1.bias\n",
+       "encoder.down_blocks.3.resnets.2.conv1.weight\n",
+       "encoder.down_blocks.3.resnets.2.conv2.bias\n",
+       "encoder.down_blocks.3.resnets.2.conv2.weight\n",
+       "encoder.down_blocks.3.resnets.2.norm1.bias\n",
+       "encoder.down_blocks.3.resnets.2.norm1.weight\n",
+       "encoder.down_blocks.3.resnets.2.norm2.bias\n",
+       "encoder.down_blocks.3.resnets.2.norm2.weight\n",
+       "encoder.down_blocks.4.resnets.2.conv1.bias\n",
+       "encoder.down_blocks.4.resnets.2.conv1.weight\n",
+       "encoder.down_blocks.4.resnets.2.conv2.bias\n",
+       "encoder.down_blocks.4.resnets.2.conv2.weight\n",
+       "encoder.down_blocks.4.resnets.2.norm1.bias\n",
+       "encoder.down_blocks.4.resnets.2.norm1.weight\n",
+       "encoder.down_blocks.4.resnets.2.norm2.bias\n",
+       "encoder.down_blocks.4.resnets.2.norm2.weight\n"
+      ]
+     }
+    ],
+    "source": [
+     "import torch\n",
+     "from diffusers import AutoencoderKL, AsymmetricAutoencoderKL\n",
+     "from tqdm import tqdm\n",
+     "import torch.nn.init as init\n",
+     "\n",
+     "def log(message):\n",
+     "    print(message)\n",
+     "\n",
+     "def main():\n",
+     "    checkpoint_path_old = \"vae16x32ch_new\"\n",
+     "    checkpoint_path_new = \"vae16x32ch_empty\"\n",
+     "    device = \"cuda\"\n",
+     "    dtype = torch.float32\n",
+     "\n",
+     "    # Load the models\n",
+     "    old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
+     "    new_unet = AutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
+     "\n",
+     "    old_state_dict = old_unet.state_dict()\n",
+     "    new_state_dict = new_unet.state_dict()\n",
+     "\n",
+     "    transferred_state_dict = {}\n",
+     "    transfer_stats = {\n",
+     "        \"transferred\": 0,\n",
+     "        \"size_mismatch\": 0,\n",
+     "        \"skipped\": 0\n",
+     "    }\n",
+     "\n",
+     "    transferred_keys = set()\n",
+     "\n",
+     "    # Process every key of the old model\n",
+     "    for old_key in tqdm(old_state_dict.keys(), desc=\"Transferring weights\"):\n",
+     "        new_key = old_key\n",
+     "\n",
+     "        if new_key in new_state_dict:\n",
+     "            if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
+     "                transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
+     "                transferred_keys.add(new_key)\n",
+     "                transfer_stats[\"transferred\"] += 1\n",
+     "            else:\n",
+     "                log(f\"✗ Size mismatch: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})\")\n",
+     "                transfer_stats[\"size_mismatch\"] += 1\n",
+     "        else:\n",
+     "            log(f\"? Key not found in the new model: {old_key} -> {old_state_dict[old_key].shape}\")\n",
+     "            transfer_stats[\"skipped\"] += 1\n",
+     "\n",
+     "    # Update the new model's state dict with the transferred weights\n",
+     "    new_state_dict.update(transferred_state_dict)\n",
+     "\n",
+     "    # Initialize weights for the new mid block\n",
+     "    #new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)\n",
+     "\n",
+     "    new_unet.load_state_dict(new_state_dict)\n",
+     "    new_unet.save_pretrained(\"vae16x32ch\")\n",
+     "\n",
+     "    # Collect the keys that exist in the new model but received no weights\n",
+     "    non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
+     "\n",
+     "    print(\"Transfer stats:\", transfer_stats)\n",
+     "    print(\"Keys not transferred in the new model:\")\n",
+     "    for key in non_transferred_keys:\n",
+     "        print(key)\n",
+     "\n",
+     "if __name__ == \"__main__\":\n",
+     "    main()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "b316ee6c-d295-4396-9177-78e39a53055b",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "The config attributes {'block_out_channels': [128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "ok\n"
+      ]
+     }
+    ],
+    "source": [
+     "import torch\n",
+     "\n",
+     "from torchvision import transforms, utils\n",
+     "\n",
+     "import diffusers\n",
+     "from diffusers import AsymmetricAutoencoderKL\n",
+     "\n",
+     "from diffusers.utils import load_image\n",
+     "\n",
+     "def crop_image_to_nearest_divisible_by_8(img):\n",
+     "    # Check if the image height and width are divisible by 8\n",
+     "    if img.shape[1] % 8 == 0 and img.shape[2] % 8 == 0:\n",
+     "        return img\n",
+     "    else:\n",
+     "        # Calculate the closest lower resolution divisible by 8\n",
+     "        new_height = img.shape[1] - (img.shape[1] % 8)\n",
+     "        new_width = img.shape[2] - (img.shape[2] % 8)\n",
+     "\n",
+     "        # Use CenterCrop to crop the image (CenterCrop takes no interpolation argument)\n",
+     "        transform = transforms.CenterCrop((new_height, new_width))\n",
+     "        img = transform(img).to(torch.float32).clamp(-1, 1)\n",
+     "\n",
+     "        return img\n",
+     "\n",
+     "to_tensor = transforms.ToTensor()\n",
+     "\n",
+     "device = \"cuda\"\n",
+     "dtype = torch.float16\n",
+     "vae = AsymmetricAutoencoderKL.from_pretrained(\"asymmetric_vae\", torch_dtype=dtype).to(device).eval()\n",
+     "\n",
+     "image = load_image(\"123456789.jpg\")\n",
+     "\n",
+     "image = crop_image_to_nearest_divisible_by_8(to_tensor(image)).unsqueeze(0).to(device, dtype=dtype)\n",
+     "\n",
+     "upscaled_image = vae(image).sample\n",
+     "# vae.config.scaling_factor\n",
+     "# Save the reconstructed image\n",
+     "utils.save_image(upscaled_image, \"test.png\")\n",
+     "print('ok')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 11,
+    "id": "5a01b8e9-73c9-4da7-a097-e334019bd8e9",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False, 'latents_mean': [-0.03542253375053406, 0.20086465775966644, -0.016413161531090736, -0.0956302210688591, -0.2672063112258911, 0.2609933018684387, -0.07806991040706635, -0.48407721519470215, 0.21844269335269928, -0.1122383326292038, 0.27197545766830444, -0.18958772718906403, 0.18776826560497284, 0.0987580344080925, 0.2837068736553192, -0.4486690163612366, 0.4816776514053345, 0.02947971224784851, -0.1337375044822693, -0.39750921726226807, -0.08513020724058151, -0.054023586213588715, -0.3943594992160797, 0.23918119072914124, -0.12466679513454437, 0.09935147315263748, 0.31858691573143005, 0.48585832118988037, -0.6416525840759277, -0.15164820849895477, -0.4693508744239807, -0.13071806728839874], 'latents_std': [1.5792087316513062, 1.5769503116607666, 1.5864241123199463, 1.6454921960830688, 1.5336694717407227, 1.5587652921676636, 1.5838669538497925, 1.5659377574920654, 1.6860467195510864, 1.5192310810089111, 1.573639988899231, 1.5953549146652222, 1.5271092653274536, 1.6246271133422852, 1.7054023742675781, 1.607722282409668, 1.558642864227295, 1.5824549198150635, 1.6202995777130127, 1.6206320524215698, 1.6379750967025757, 1.6527063846588135, 1.498811960220337, 1.5706247091293335, 1.5854856967926025, 1.4828169345855713, 1.5693111419677734, 1.692481517791748, 1.6409776210784912, 1.6216280460357666, 1.6087706089019775, 1.5776633024215698]} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
+       "Transferring weights: 100%|██████████| 284/284 [00:00<00:00, 30094.80it/s]\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Stats: {'transferred': 292, 'size_mismatch': 0, 'skipped': 10}\n",
+       "\n",
+       "Keys not transferred:\n"
+      ]
+     }
+    ],
+    "source": [
+     "import torch\n",
+     "from diffusers import AutoencoderKL, AsymmetricAutoencoderKL\n",
+     "from tqdm import tqdm\n",
+     "\n",
+     "\n",
+     "def log(message):\n",
+     "    print(message)\n",
+     "\n",
+     "\n",
+     "def remap_key(old_key: str):\n",
+     "    \"\"\"\n",
+     "    Shift only the encoder.down_blocks indices.\n",
+     "    \"\"\"\n",
+     "\n",
+     "    if \"encoder.down_blocks\" not in old_key:\n",
+     "        return [old_key]\n",
+     "\n",
+     "    parts = old_key.split(\".\")\n",
+     "    block_id = int(parts[2])\n",
+     "\n",
+     "    if block_id == 0:\n",
+     "        # copy the first block twice: it stays at index 0 and also seeds index 1\n",
+     "        return [\n",
+     "            old_key,\n",
+     "            old_key.replace(\"down_blocks.0\", \"down_blocks.1\"),\n",
+     "        ]\n",
+     "\n",
+     "    # shift the remaining blocks up by one\n",
+     "    new_block = block_id + 1\n",
+     "    return [old_key.replace(f\"down_blocks.{block_id}\", f\"down_blocks.{new_block}\")]\n",
+     "\n",
+     "\n",
+     "def main():\n",
+     "    checkpoint_path_old = \"asymmetric_vae_new\"\n",
+     "    checkpoint_path_new = \"vae16x32ch_empty\"\n",
+     "\n",
+     "    device = \"cuda\"\n",
+     "    dtype = torch.float32\n",
+     "\n",
+     "    old_vae = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
+     "    new_vae = AutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
+     "\n",
+     "    old_state_dict = old_vae.state_dict()\n",
+     "    new_state_dict = new_vae.state_dict()\n",
+     "\n",
+     "    transferred_state_dict = {}\n",
+     "    transferred_keys = set()\n",
+     "\n",
+     "    transfer_stats = {\n",
+     "        \"transferred\": 0,\n",
+     "        \"size_mismatch\": 0,\n",
+     "        \"skipped\": 0\n",
+     "    }\n",
+     "\n",
+     "    for old_key in tqdm(old_state_dict.keys(), desc=\"Transferring weights\"):\n",
+     "\n",
+     "        new_keys = remap_key(old_key)\n",
+     "\n",
+     "        for new_key in new_keys:\n",
+     "\n",
+     "            if new_key in new_state_dict:\n",
+     "\n",
+     "                if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
+     "                    transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
+     "                    transferred_keys.add(new_key)\n",
+     "                    transfer_stats[\"transferred\"] += 1\n",
+     "                else:\n",
+     "                    log(\n",
+     "                        f\"✗ Size mismatch: \"\n",
+     "                        f\"{old_key} {old_state_dict[old_key].shape} \"\n",
+     "                        f\"-> {new_key} {new_state_dict[new_key].shape}\"\n",
+     "                    )\n",
+     "                    transfer_stats[\"size_mismatch\"] += 1\n",
+     "            else:\n",
+     "                transfer_stats[\"skipped\"] += 1\n",
+     "\n",
+     "    new_state_dict.update(transferred_state_dict)\n",
+     "\n",
+     "    new_vae.load_state_dict(new_state_dict)\n",
+     "    new_vae.save_pretrained(\"vae16x32ch\")\n",
+     "\n",
+     "    non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
+     "\n",
+     "    print(\"Stats:\", transfer_stats)\n",
+     "\n",
+     "    print(\"\\nKeys not transferred:\")\n",
+     "    for key in non_transferred_keys:\n",
+     "        print(key)\n",
+     "\n",
+     "\n",
+     "if __name__ == \"__main__\":\n",
+     "    main()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "fe8f1ceb-8d3e-4df5-a1dc-1b56a0d398a2",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.12"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
config.json CHANGED
@@ -1,7 +1,7 @@
 {
   "_class_name": "AutoencoderKL",
   "_diffusers_version": "0.37.0",
-  "_name_or_path": "vae16x32ch_new",
+  "_name_or_path": "vae16x32ch",
   "act_fn": "silu",
   "block_out_channels": [
     128,
@@ -88,7 +88,7 @@
     1.6087706089019775,
     1.5776633024215698
   ],
-  "layers_per_block": 2,
+  "layers_per_block": 3,
   "mid_block_add_attention": true,
   "norm_num_groups": 32,
   "out_channels": 3,
create_symmetric-Copy1.ipynb ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "407171be-ab46-442b-a0bd-83ca75173eba",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "AutoencoderKL(\n",
14
+ " (encoder): Encoder(\n",
15
+ " (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
16
+ " (down_blocks): ModuleList(\n",
17
+ " (0-1): 2 x DownEncoderBlock2D(\n",
18
+ " (resnets): ModuleList(\n",
19
+ " (0-2): 3 x ResnetBlock2D(\n",
20
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
21
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
22
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
23
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
24
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
25
+ " (nonlinearity): SiLU()\n",
26
+ " )\n",
27
+ " )\n",
28
+ " (downsamplers): ModuleList(\n",
29
+ " (0): Downsample2D(\n",
30
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
31
+ " )\n",
32
+ " )\n",
33
+ " )\n",
34
+ " (2): DownEncoderBlock2D(\n",
35
+ " (resnets): ModuleList(\n",
36
+ " (0): ResnetBlock2D(\n",
37
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
38
+ " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
39
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
40
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
41
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
42
+ " (nonlinearity): SiLU()\n",
43
+ " (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
44
+ " )\n",
45
+ " (1-2): 2 x ResnetBlock2D(\n",
46
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
47
+ " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
48
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
49
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
50
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
51
+ " (nonlinearity): SiLU()\n",
52
+ " )\n",
53
+ " )\n",
54
+ " (downsamplers): ModuleList(\n",
55
+ " (0): Downsample2D(\n",
56
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
57
+ " )\n",
58
+ " )\n",
59
+ " )\n",
60
+ " (3): DownEncoderBlock2D(\n",
61
+ " (resnets): ModuleList(\n",
62
+ " (0): ResnetBlock2D(\n",
63
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
64
+ " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
65
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
66
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
67
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
68
+ " (nonlinearity): SiLU()\n",
69
+ " (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
70
+ " )\n",
71
+ " (1-2): 2 x ResnetBlock2D(\n",
72
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
73
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
74
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
75
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
76
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
77
+ " (nonlinearity): SiLU()\n",
78
+ " )\n",
79
+ " )\n",
80
+ " (downsamplers): ModuleList(\n",
81
+ " (0): Downsample2D(\n",
82
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n",
83
+ " )\n",
84
+ " )\n",
85
+ " )\n",
86
+ " (4): DownEncoderBlock2D(\n",
87
+ " (resnets): ModuleList(\n",
88
+ " (0-2): 3 x ResnetBlock2D(\n",
89
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
90
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
91
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
92
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
93
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
94
+ " (nonlinearity): SiLU()\n",
95
+ " )\n",
96
+ " )\n",
97
+ " )\n",
98
+ " )\n",
99
+ " (mid_block): UNetMidBlock2D(\n",
100
+ " (attentions): ModuleList(\n",
101
+ " (0): Attention(\n",
102
+ " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
103
+ " (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
104
+ " (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
105
+ " (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
106
+ " (to_out): ModuleList(\n",
107
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
108
+ " (1): Dropout(p=0.0, inplace=False)\n",
109
+ " )\n",
110
+ " )\n",
111
+ " )\n",
112
+ " (resnets): ModuleList(\n",
113
+ " (0-1): 2 x ResnetBlock2D(\n",
114
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
115
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
116
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
117
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
118
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
119
+ " (nonlinearity): SiLU()\n",
120
+ " )\n",
121
+ " )\n",
122
+ " )\n",
123
+ " (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
124
+ " (conv_act): SiLU()\n",
125
+ " (conv_out): Conv2d(512, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
126
+ " )\n",
127
+ " (decoder): Decoder(\n",
128
+ " (conv_in): Conv2d(32, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
129
+ " (up_blocks): ModuleList(\n",
130
+ " (0-1): 2 x UpDecoderBlock2D(\n",
131
+ " (resnets): ModuleList(\n",
132
+ " (0-3): 4 x ResnetBlock2D(\n",
133
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
134
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
135
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
136
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
137
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
138
+ " (nonlinearity): SiLU()\n",
139
+ " )\n",
140
+ " )\n",
141
+ " (upsamplers): ModuleList(\n",
142
+ " (0): Upsample2D(\n",
143
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
144
+ " )\n",
145
+ " )\n",
146
+ " )\n",
147
+ " (2): UpDecoderBlock2D(\n",
148
+ " (resnets): ModuleList(\n",
149
+ " (0): ResnetBlock2D(\n",
150
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
151
+ " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
152
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
153
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
154
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
155
+ " (nonlinearity): SiLU()\n",
156
+ " (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
157
+ " )\n",
158
+ " (1-3): 3 x ResnetBlock2D(\n",
159
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
160
+ " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
161
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
162
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
163
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
164
+ " (nonlinearity): SiLU()\n",
165
+ " )\n",
166
+ " )\n",
167
+ " (upsamplers): ModuleList(\n",
168
+ " (0): Upsample2D(\n",
169
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
170
+ " )\n",
171
+ " )\n",
172
+ " )\n",
173
+ " (3): UpDecoderBlock2D(\n",
174
+ " (resnets): ModuleList(\n",
175
+ " (0): ResnetBlock2D(\n",
176
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
177
+ " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
178
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
179
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
180
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
181
+ " (nonlinearity): SiLU()\n",
182
+ " (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
183
+ " )\n",
184
+ " (1-3): 3 x ResnetBlock2D(\n",
185
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
186
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
187
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
188
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
189
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
190
+ " (nonlinearity): SiLU()\n",
191
+ " )\n",
192
+ " )\n",
193
+ " (upsamplers): ModuleList(\n",
194
+ " (0): Upsample2D(\n",
195
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
196
+ " )\n",
197
+ " )\n",
198
+ " )\n",
199
+ " (4): UpDecoderBlock2D(\n",
200
+ " (resnets): ModuleList(\n",
201
+ " (0-3): 4 x ResnetBlock2D(\n",
202
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
203
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
204
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
205
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
206
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
207
+ " (nonlinearity): SiLU()\n",
208
+ " )\n",
209
+ " )\n",
210
+ " )\n",
211
+ " )\n",
212
+ " (mid_block): UNetMidBlock2D(\n",
213
+ " (attentions): ModuleList(\n",
214
+ " (0): Attention(\n",
215
+ " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
216
+ " (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
217
+ " (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
218
+ " (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
219
+ " (to_out): ModuleList(\n",
220
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
221
+ " (1): Dropout(p=0.0, inplace=False)\n",
222
+ " )\n",
223
+ " )\n",
224
+ " )\n",
225
+ " (resnets): ModuleList(\n",
226
+ " (0-1): 2 x ResnetBlock2D(\n",
227
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
228
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
229
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
230
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
231
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
232
+ " (nonlinearity): SiLU()\n",
233
+ " )\n",
234
+ " )\n",
235
+ " )\n",
236
+ " (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
237
+ " (conv_act): SiLU()\n",
238
+ " (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
239
+ " )\n",
240
+ " (quant_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
241
+ " (post_quant_conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
242
+ ")\n"
243
+ ]
244
+ }
245
+ ],
246
+ "source": [
247
+ "from diffusers.models import AutoencoderKL\n",
248
+ "import torch\n",
249
+ "\n",
250
+ "config = {\n",
251
+ " \"_class_name\": \"AutoencoderKL\",\n",
252
+ " \"_diffusers_version\": \"0.36.0\",\n",
253
+ " \"act_fn\": \"silu\",\n",
254
+ " \"block_out_channels\": [\n",
255
+ " 128,\n",
256
+ " 128,\n",
257
+ " 256,\n",
258
+ " 512,\n",
259
+ " 512\n",
260
+ " ],\n",
261
+ " \"down_block_types\": [\n",
262
+ " \"DownEncoderBlock2D\",\n",
263
+ " \"DownEncoderBlock2D\",\n",
264
+ " \"DownEncoderBlock2D\",\n",
265
+ " \"DownEncoderBlock2D\",\n",
266
+ " \"DownEncoderBlock2D\"\n",
267
+ " ],\n",
268
+ " \"force_upcast\": False,\n",
269
+ " \"in_channels\": 3,\n",
270
+ " \"latent_channels\": 32,\n",
271
+ " \"latents_mean\": [\n",
272
+ " -0.03542253375053406,\n",
273
+ " 0.20086465775966644,\n",
274
+ " -0.016413161531090736,\n",
275
+ " -0.0956302210688591,\n",
276
+ " -0.2672063112258911,\n",
277
+ " 0.2609933018684387,\n",
278
+ " -0.07806991040706635,\n",
279
+ " -0.48407721519470215,\n",
280
+ " 0.21844269335269928,\n",
281
+ " -0.1122383326292038,\n",
282
+ " 0.27197545766830444,\n",
283
+ " -0.18958772718906403,\n",
284
+ " 0.18776826560497284,\n",
285
+ " 0.0987580344080925,\n",
286
+ " 0.2837068736553192,\n",
287
+ " -0.4486690163612366,\n",
288
+ " 0.4816776514053345,\n",
289
+ " 0.02947971224784851,\n",
290
+ " -0.1337375044822693,\n",
291
+ " -0.39750921726226807,\n",
292
+ " -0.08513020724058151,\n",
293
+ " -0.054023586213588715,\n",
294
+ " -0.3943594992160797,\n",
295
+ " 0.23918119072914124,\n",
296
+ " -0.12466679513454437,\n",
297
+ " 0.09935147315263748,\n",
298
+ " 0.31858691573143005,\n",
299
+ " 0.48585832118988037,\n",
300
+ " -0.6416525840759277,\n",
301
+ " -0.15164820849895477,\n",
302
+ " -0.4693508744239807,\n",
303
+ " -0.13071806728839874\n",
304
+ " ],\n",
305
+ " \"latents_std\": [\n",
306
+ " 1.5792087316513062,\n",
307
+ " 1.5769503116607666,\n",
308
+ " 1.5864241123199463,\n",
309
+ " 1.6454921960830688,\n",
310
+ " 1.5336694717407227,\n",
311
+ " 1.5587652921676636,\n",
312
+ " 1.5838669538497925,\n",
313
+ " 1.5659377574920654,\n",
314
+ " 1.6860467195510864,\n",
315
+ " 1.5192310810089111,\n",
316
+ " 1.573639988899231,\n",
317
+ " 1.5953549146652222,\n",
318
+ " 1.5271092653274536,\n",
319
+ " 1.6246271133422852,\n",
320
+ " 1.7054023742675781,\n",
321
+ " 1.607722282409668,\n",
322
+ " 1.558642864227295,\n",
323
+ " 1.5824549198150635,\n",
324
+ " 1.6202995777130127,\n",
325
+ " 1.6206320524215698,\n",
326
+ " 1.6379750967025757,\n",
327
+ " 1.6527063846588135,\n",
328
+ " 1.498811960220337,\n",
329
+ " 1.5706247091293335,\n",
330
+ " 1.5854856967926025,\n",
331
+ " 1.4828169345855713,\n",
332
+ " 1.5693111419677734,\n",
333
+ " 1.692481517791748,\n",
334
+ " 1.6409776210784912,\n",
335
+ " 1.6216280460357666,\n",
336
+ " 1.6087706089019775,\n",
337
+ " 1.5776633024215698\n",
338
+ " ],\n",
339
+ " \"layers_per_block\": 2,\n",
340
+ " \"mid_block_add_attention\": True,\n",
341
+ " \"norm_num_groups\": 32,\n",
342
+ " \"out_channels\": 3,\n",
343
+ " \"sample_size\": 32,\n",
344
+ " \"scaling_factor\": 1.0,\n",
345
+ " \"shift_factor\": 0.0,\n",
346
+ " \"up_block_types\": [\n",
347
+ " \"UpDecoderBlock2D\",\n",
348
+ " \"UpDecoderBlock2D\",\n",
349
+ " \"UpDecoderBlock2D\",\n",
350
+ " \"UpDecoderBlock2D\",\n",
351
+ " \"UpDecoderBlock2D\"\n",
352
+ " ],\n",
353
+ " \"use_post_quant_conv\": True,\n",
354
+ " \"use_quant_conv\": True\n",
355
+ "}\n",
356
+ "\n",
357
+ "\n",
358
+ "vae = AutoencoderKL(\n",
359
+ " act_fn=config[\"act_fn\"],\n",
360
+ " block_out_channels=config[\"block_out_channels\"],\n",
361
+ " down_block_types=config[\"down_block_types\"],\n",
362
+ " up_block_types=config[\"up_block_types\"],\n",
363
+ " in_channels=config[\"in_channels\"],\n",
364
+ " out_channels=config[\"out_channels\"],\n",
365
+ " latent_channels=config[\"latent_channels\"],\n",
366
+ " layers_per_block=3, #config[\"layers_per_block\"],\n",
367
+ " norm_num_groups=config[\"norm_num_groups\"],\n",
368
+ " sample_size=config[\"sample_size\"],\n",
369
+ " scaling_factor=config[\"scaling_factor\"],\n",
370
+ " force_upcast=config[\"force_upcast\"],\n",
371
+ " mid_block_add_attention=config[\"mid_block_add_attention\"],\n",
372
+ " use_quant_conv=config[\"use_quant_conv\"],\n",
373
+ " use_post_quant_conv=config[\"use_post_quant_conv\"],\n",
374
+ " latents_mean=(config[\"latents_mean\"]),\n",
375
+ " latents_std=(config[\"latents_std\"]),\n",
376
+ ")\n",
377
+ "\n",
378
+ "vae.save_pretrained(\"vae16x32ch_empty\")\n",
379
+ "\n",
380
+ "print(vae)"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": 1,
386
+ "id": "a2950158-5203-42b9-8791-e231ddbf1063",
387
+ "metadata": {},
388
+ "outputs": [
389
+ {
390
+ "name": "stderr",
391
+ "output_type": "stream",
392
+ "text": [
393
+ "Перенос весов: 100%|██████████| 292/292 [00:00<00:00, 35760.83it/s]\n"
394
+ ]
395
+ },
396
+ {
397
+ "name": "stdout",
398
+ "output_type": "stream",
399
+ "text": [
400
+ "Статистика переноса: {'перенесено': 292, 'несовпадение_размеров': 0, 'пропущено': 0}\n",
401
+ "Неперенесенные ключи в новой модели:\n",
402
+ "decoder.up_blocks.0.resnets.3.conv1.bias\n",
403
+ "decoder.up_blocks.0.resnets.3.conv1.weight\n",
404
+ "decoder.up_blocks.0.resnets.3.conv2.bias\n",
405
+ "decoder.up_blocks.0.resnets.3.conv2.weight\n",
406
+ "decoder.up_blocks.0.resnets.3.norm1.bias\n",
407
+ "decoder.up_blocks.0.resnets.3.norm1.weight\n",
408
+ "decoder.up_blocks.0.resnets.3.norm2.bias\n",
409
+ "decoder.up_blocks.0.resnets.3.norm2.weight\n",
410
+ "decoder.up_blocks.1.resnets.3.conv1.bias\n",
411
+ "decoder.up_blocks.1.resnets.3.conv1.weight\n",
412
+ "decoder.up_blocks.1.resnets.3.conv2.bias\n",
413
+ "decoder.up_blocks.1.resnets.3.conv2.weight\n",
414
+ "decoder.up_blocks.1.resnets.3.norm1.bias\n",
415
+ "decoder.up_blocks.1.resnets.3.norm1.weight\n",
416
+ "decoder.up_blocks.1.resnets.3.norm2.bias\n",
417
+ "decoder.up_blocks.1.resnets.3.norm2.weight\n",
418
+ "decoder.up_blocks.2.resnets.3.conv1.bias\n",
419
+ "decoder.up_blocks.2.resnets.3.conv1.weight\n",
420
+ "decoder.up_blocks.2.resnets.3.conv2.bias\n",
421
+ "decoder.up_blocks.2.resnets.3.conv2.weight\n",
422
+ "decoder.up_blocks.2.resnets.3.norm1.bias\n",
423
+ "decoder.up_blocks.2.resnets.3.norm1.weight\n",
424
+ "decoder.up_blocks.2.resnets.3.norm2.bias\n",
425
+ "decoder.up_blocks.2.resnets.3.norm2.weight\n",
426
+ "decoder.up_blocks.3.resnets.3.conv1.bias\n",
427
+ "decoder.up_blocks.3.resnets.3.conv1.weight\n",
428
+ "decoder.up_blocks.3.resnets.3.conv2.bias\n",
429
+ "decoder.up_blocks.3.resnets.3.conv2.weight\n",
430
+ "decoder.up_blocks.3.resnets.3.norm1.bias\n",
431
+ "decoder.up_blocks.3.resnets.3.norm1.weight\n",
432
+ "decoder.up_blocks.3.resnets.3.norm2.bias\n",
433
+ "decoder.up_blocks.3.resnets.3.norm2.weight\n",
434
+ "decoder.up_blocks.4.resnets.3.conv1.bias\n",
435
+ "decoder.up_blocks.4.resnets.3.conv1.weight\n",
436
+ "decoder.up_blocks.4.resnets.3.conv2.bias\n",
437
+ "decoder.up_blocks.4.resnets.3.conv2.weight\n",
438
+ "decoder.up_blocks.4.resnets.3.norm1.bias\n",
439
+ "decoder.up_blocks.4.resnets.3.norm1.weight\n",
440
+ "decoder.up_blocks.4.resnets.3.norm2.bias\n",
441
+ "decoder.up_blocks.4.resnets.3.norm2.weight\n",
442
+ "encoder.down_blocks.0.resnets.2.conv1.bias\n",
443
+ "encoder.down_blocks.0.resnets.2.conv1.weight\n",
444
+ "encoder.down_blocks.0.resnets.2.conv2.bias\n",
445
+ "encoder.down_blocks.0.resnets.2.conv2.weight\n",
446
+ "encoder.down_blocks.0.resnets.2.norm1.bias\n",
447
+ "encoder.down_blocks.0.resnets.2.norm1.weight\n",
448
+ "encoder.down_blocks.0.resnets.2.norm2.bias\n",
449
+ "encoder.down_blocks.0.resnets.2.norm2.weight\n",
450
+ "encoder.down_blocks.1.resnets.2.conv1.bias\n",
451
+ "encoder.down_blocks.1.resnets.2.conv1.weight\n",
452
+ "encoder.down_blocks.1.resnets.2.conv2.bias\n",
453
+ "encoder.down_blocks.1.resnets.2.conv2.weight\n",
454
+ "encoder.down_blocks.1.resnets.2.norm1.bias\n",
455
+ "encoder.down_blocks.1.resnets.2.norm1.weight\n",
456
+ "encoder.down_blocks.1.resnets.2.norm2.bias\n",
457
+ "encoder.down_blocks.1.resnets.2.norm2.weight\n",
458
+ "encoder.down_blocks.2.resnets.2.conv1.bias\n",
459
+ "encoder.down_blocks.2.resnets.2.conv1.weight\n",
460
+ "encoder.down_blocks.2.resnets.2.conv2.bias\n",
461
+ "encoder.down_blocks.2.resnets.2.conv2.weight\n",
462
+ "encoder.down_blocks.2.resnets.2.norm1.bias\n",
463
+ "encoder.down_blocks.2.resnets.2.norm1.weight\n",
464
+ "encoder.down_blocks.2.resnets.2.norm2.bias\n",
465
+ "encoder.down_blocks.2.resnets.2.norm2.weight\n",
466
+ "encoder.down_blocks.3.resnets.2.conv1.bias\n",
467
+ "encoder.down_blocks.3.resnets.2.conv1.weight\n",
468
+ "encoder.down_blocks.3.resnets.2.conv2.bias\n",
469
+ "encoder.down_blocks.3.resnets.2.conv2.weight\n",
470
+ "encoder.down_blocks.3.resnets.2.norm1.bias\n",
471
+ "encoder.down_blocks.3.resnets.2.norm1.weight\n",
472
+ "encoder.down_blocks.3.resnets.2.norm2.bias\n",
473
+ "encoder.down_blocks.3.resnets.2.norm2.weight\n",
474
+ "encoder.down_blocks.4.resnets.2.conv1.bias\n",
475
+ "encoder.down_blocks.4.resnets.2.conv1.weight\n",
476
+ "encoder.down_blocks.4.resnets.2.conv2.bias\n",
477
+ "encoder.down_blocks.4.resnets.2.conv2.weight\n",
478
+ "encoder.down_blocks.4.resnets.2.norm1.bias\n",
479
+ "encoder.down_blocks.4.resnets.2.norm1.weight\n",
480
+ "encoder.down_blocks.4.resnets.2.norm2.bias\n",
481
+ "encoder.down_blocks.4.resnets.2.norm2.weight\n"
482
+ ]
483
+ }
484
+ ],
485
+ "source": [
486
+ "import torch\n",
487
+ "from diffusers import AutoencoderKL,AsymmetricAutoencoderKL\n",
488
+ "from tqdm import tqdm\n",
489
+ "import torch.nn.init as init\n",
490
+ "\n",
491
+ "def log(message):\n",
492
+ " print(message)\n",
493
+ "\n",
494
+ "def main():\n",
495
+ " checkpoint_path_old = \"vae16x32ch_new\"\n",
496
+ " checkpoint_path_new = \"vae16x32ch_empty\"\n",
497
+ " device = \"cuda\"\n",
498
+ " dtype = torch.float32\n",
499
+ "\n",
500
+ " # Загрузка моделей\n",
501
+ " old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
502
+ " new_unet = AutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
503
+ "\n",
504
+ " old_state_dict = old_unet.state_dict()\n",
505
+ " new_state_dict = new_unet.state_dict()\n",
506
+ "\n",
507
+ " transferred_state_dict = {}\n",
508
+ " transfer_stats = {\n",
509
+ " \"перенесено\": 0,\n",
510
+ " \"несовпадение_размеров\": 0,\n",
511
+ " \"пропущено\": 0\n",
512
+ " }\n",
513
+ "\n",
514
+ " transferred_keys = set()\n",
515
+ "\n",
516
+ " # Обрабатываем каждый ключ старой модели\n",
517
+ " for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n",
518
+ " new_key = old_key\n",
519
+ "\n",
520
+ " if new_key in new_state_dict:\n",
521
+ " if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
522
+ " transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
523
+ " transferred_keys.add(new_key)\n",
524
+ " transfer_stats[\"перенесено\"] += 1\n",
525
+ " else:\n",
526
+ " log(f\"✗ Несовпадение размеров: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})\")\n",
527
+ " transfer_stats[\"несовпадение_размеров\"] += 1\n",
528
+ " else:\n",
529
+ " log(f\"? Ключ не найден в новой модели: {old_key} -> {old_state_dict[old_key].shape}\")\n",
530
+ " transfer_stats[\"пропущено\"] += 1\n",
531
+ "\n",
532
+ " # Обновляем состояние новой модели перенесенными весами\n",
533
+ " new_state_dict.update(transferred_state_dict)\n",
534
+ " \n",
535
+ " # Инициализируем веса для нового mid блока\n",
536
+ " #new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)\n",
537
+ " \n",
538
+ " new_unet.load_state_dict(new_state_dict)\n",
539
+ " new_unet.save_pretrained(\"vae16x32ch\")\n",
540
+ "\n",
541
+ " # Получаем список неперенесенных ключей\n",
542
+ " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
543
+ "\n",
544
+ " print(\"Статистика переноса:\", transfer_stats)\n",
545
+ " print(\"Неперенесенные ключи в новой модели:\")\n",
546
+ " for key in non_transferred_keys:\n",
547
+ " print(key)\n",
548
+ "\n",
549
+ "if __name__ == \"__main__\":\n",
550
+ " main()"
551
+ ]
552
+ },
553
+ {
554
+ "cell_type": "code",
555
+ "execution_count": 1,
556
+ "id": "b316ee6c-d295-4396-9177-78e39a53055b",
557
+ "metadata": {},
558
+ "outputs": [
559
+ {
560
+ "name": "stderr",
561
+ "output_type": "stream",
562
+ "text": [
563
+ "The config attributes {'block_out_channels': [128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
564
+ ]
565
+ },
566
+ {
567
+ "name": "stdout",
568
+ "output_type": "stream",
569
+ "text": [
570
+ "ok\n"
571
+ ]
572
+ }
573
+ ],
574
+ "source": [
575
+ "import torch\n",
576
+ "\n",
577
+ "from torchvision import transforms, utils\n",
578
+ "\n",
579
+ "import diffusers\n",
580
+ "from diffusers import AsymmetricAutoencoderKL\n",
581
+ "\n",
582
+ "from diffusers.utils import load_image\n",
583
+ "\n",
584
+ "def crop_image_to_nearest_divisible_by_8(img):\n",
585
+ " # Check if the image height and width are divisible by 8\n",
586
+ " if img.shape[1] % 8 == 0 and img.shape[2] % 8 == 0:\n",
587
+ " return img\n",
588
+ " else:\n",
589
+ " # Calculate the closest lower resolution divisible by 8\n",
590
+ " new_height = img.shape[1] - (img.shape[1] % 8)\n",
591
+ " new_width = img.shape[2] - (img.shape[2] % 8)\n",
592
+ " \n",
593
+ " # Use CenterCrop to crop the image\n",
594
+ " transform = transforms.CenterCrop((new_height, new_width), interpolation=transforms.InterpolationMode.BILINEAR)\n",
595
+ " img = transform(img).to(torch.float32).clamp(-1, 1)\n",
596
+ " \n",
597
+ " return img\n",
598
+ " \n",
599
+ "to_tensor = transforms.ToTensor()\n",
600
+ "\n",
601
+ "device = \"cuda\"\n",
602
+ "dtype=torch.float16\n",
603
+ "vae = AsymmetricAutoencoderKL.from_pretrained(\"asymmetric_vae\",torch_dtype=dtype).to(device).eval()\n",
604
+ "\n",
605
+ "image = load_image(\"123456789.jpg\")\n",
606
+ "\n",
607
+ "image = crop_image_to_nearest_divisible_by_8(to_tensor(image)).unsqueeze(0).to(device,dtype=dtype)\n",
608
+ "\n",
609
+ "upscaled_image = vae(image).sample\n",
610
+ "#vae.config.scaled_factor\n",
611
+ "# Save the reconstructed image\n",
612
+ "utils.save_image(upscaled_image, \"test.png\")\n",
613
+ "print('ok')"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "execution_count": 11,
619
+ "id": "5a01b8e9-73c9-4da7-a097-e334019bd8e9",
620
+ "metadata": {},
621
+ "outputs": [
622
+ {
623
+ "name": "stderr",
624
+ "output_type": "stream",
625
+ "text": [
626
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False, 'latents_mean': [-0.03542253375053406, 0.20086465775966644, -0.016413161531090736, -0.0956302210688591, -0.2672063112258911, 0.2609933018684387, -0.07806991040706635, -0.48407721519470215, 0.21844269335269928, -0.1122383326292038, 0.27197545766830444, -0.18958772718906403, 0.18776826560497284, 0.0987580344080925, 0.2837068736553192, -0.4486690163612366, 0.4816776514053345, 0.02947971224784851, -0.1337375044822693, -0.39750921726226807, -0.08513020724058151, -0.054023586213588715, -0.3943594992160797, 0.23918119072914124, -0.12466679513454437, 0.09935147315263748, 0.31858691573143005, 0.48585832118988037, -0.6416525840759277, -0.15164820849895477, -0.4693508744239807, -0.13071806728839874], 'latents_std': [1.5792087316513062, 1.5769503116607666, 1.5864241123199463, 1.6454921960830688, 1.5336694717407227, 1.5587652921676636, 1.5838669538497925, 1.5659377574920654, 1.6860467195510864, 1.5192310810089111, 1.573639988899231, 1.5953549146652222, 1.5271092653274536, 1.6246271133422852, 1.7054023742675781, 1.607722282409668, 1.558642864227295, 1.5824549198150635, 1.6202995777130127, 1.6206320524215698, 1.6379750967025757, 1.6527063846588135, 1.498811960220337, 1.5706247091293335, 1.5854856967926025, 1.4828169345855713, 1.5693111419677734, 1.692481517791748, 1.6409776210784912, 1.6216280460357666, 1.6087706089019775, 1.5776633024215698]} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
627
+ "Перенос весов: 100%|██████████| 284/284 [00:00<00:00, 30094.80it/s]\n"
628
+ ]
629
+ },
630
+ {
631
+ "name": "stdout",
632
+ "output_type": "stream",
633
+ "text": [
634
+ "Статистика: {'перенесено': 292, 'несовпадение_размеров': 0, 'пропущено': 10}\n",
635
+ "\n",
636
+ "Неперенесенные ��лючи:\n"
637
+ ]
638
+ }
639
+ ],
640
+ "source": [
641
+ "import torch\n",
642
+ "from diffusers import AutoencoderKL, AsymmetricAutoencoderKL\n",
643
+ "from tqdm import tqdm\n",
644
+ "\n",
645
+ "\n",
646
+ "def log(message):\n",
647
+ " print(message)\n",
648
+ "\n",
649
+ "\n",
650
+ "def remap_key(old_key: str):\n",
651
+ " \"\"\"\n",
652
+ " Смещение только encoder.down_blocks\n",
653
+ " \"\"\"\n",
654
+ "\n",
655
+ " if \"encoder.down_blocks\" not in old_key:\n",
656
+ " return [old_key]\n",
657
+ "\n",
658
+ " parts = old_key.split(\".\")\n",
659
+ " block_id = int(parts[2])\n",
660
+ "\n",
661
+ " if block_id == 0:\n",
662
+ " # первый блок копируем дважды\n",
663
+ " return [\n",
664
+ " old_key.replace(\"down_blocks.0\", \"down_blocks.0\"),\n",
665
+ " old_key.replace(\"down_blocks.0\", \"down_blocks.1\"),\n",
666
+ " ]\n",
667
+ "\n",
668
+ " # остальные блоки сдвигаем\n",
669
+ " new_block = block_id + 1\n",
670
+ " return [old_key.replace(f\"down_blocks.{block_id}\", f\"down_blocks.{new_block}\")]\n",
671
+ "\n",
672
+ "\n",
673
+ "def main():\n",
674
+ " checkpoint_path_old = \"asymmetric_vae_new\"\n",
675
+ " checkpoint_path_new = \"vae16x32ch_empty\"\n",
676
+ "\n",
677
+ " device = \"cuda\"\n",
678
+ " dtype = torch.float32\n",
679
+ "\n",
680
+ " old_vae = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
681
+ " new_vae = AutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
682
+ "\n",
683
+ " old_state_dict = old_vae.state_dict()\n",
684
+ " new_state_dict = new_vae.state_dict()\n",
685
+ "\n",
686
+ " transferred_state_dict = {}\n",
687
+ " transferred_keys = set()\n",
688
+ "\n",
689
+ " transfer_stats = {\n",
690
+ " \"перенесено\": 0,\n",
691
+ " \"несовпадение_размеров\": 0,\n",
692
+ " \"пропущено\": 0\n",
693
+ " }\n",
694
+ "\n",
695
+ " for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n",
696
+ "\n",
697
+ " new_keys = remap_key(old_key)\n",
698
+ "\n",
699
+ " for new_key in new_keys:\n",
700
+ "\n",
701
+ " if new_key in new_state_dict:\n",
702
+ "\n",
703
+ " if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
704
+ " transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
705
+ " transferred_keys.add(new_key)\n",
706
+ " transfer_stats[\"перенесено\"] += 1\n",
707
+ " else:\n",
708
+ " log(\n",
709
+ " f\"✗ Несовпадение размеров: \"\n",
710
+ " f\"{old_key} {old_state_dict[old_key].shape} \"\n",
711
+ " f\"-> {new_key} {new_state_dict[new_key].shape}\"\n",
712
+ " )\n",
713
+ " transfer_stats[\"несовпадение_размеров\"] += 1\n",
714
+ " else:\n",
715
+ " transfer_stats[\"пропущено\"] += 1\n",
716
+ "\n",
717
+ " new_state_dict.update(transferred_state_dict)\n",
718
+ "\n",
719
+ " new_vae.load_state_dict(new_state_dict)\n",
720
+ " new_vae.save_pretrained(\"vae16x32ch\")\n",
721
+ "\n",
722
+ " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
723
+ "\n",
724
+ " print(\"Статистика:\", transfer_stats)\n",
725
+ "\n",
726
+ " print(\"\\nНеперенесенные ключи:\")\n",
727
+ " for key in non_transferred_keys:\n",
728
+ " print(key)\n",
729
+ "\n",
730
+ "\n",
731
+ "if __name__ == \"__main__\":\n",
732
+ " main()"
733
+ ]
734
+ },
735
+ {
736
+ "cell_type": "code",
737
+ "execution_count": null,
738
+ "id": "fe8f1ceb-8d3e-4df5-a1dc-1b56a0d398a2",
739
+ "metadata": {},
740
+ "outputs": [],
741
+ "source": []
742
+ }
743
+ ],
744
+ "metadata": {
745
+ "kernelspec": {
746
+ "display_name": "Python3 (ipykernel)",
747
+ "language": "python",
748
+ "name": "python3"
749
+ },
750
+ "language_info": {
751
+ "codemirror_mode": {
752
+ "name": "ipython",
753
+ "version": 3
754
+ },
755
+ "file_extension": ".py",
756
+ "mimetype": "text/x-python",
757
+ "name": "python",
758
+ "nbconvert_exporter": "python",
759
+ "pygments_lexer": "ipython3",
760
+ "version": "3.12.12"
761
+ }
762
+ },
763
+ "nbformat": 4,
764
+ "nbformat_minor": 5
765
+ }
diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d349936a1faab1555bf81e68ba1e6fd2b84f6a6a46ffd11470079581d48bdfea
3
- size 343311604
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccb16e414b415d5addff3f309a6a2f1e2ba39145587b1b2b6aa77b82ab20d5a4
3
+ size 433047700
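The checkpoint growth from 343311604 to 433047700 bytes (~89.7 MB) is consistent with the ten ResnetBlock2D modules added above: one extra resnet in each of the five encoder down blocks and five decoder up blocks. A back-of-the-envelope check, assuming block_out_channels = [128, 128, 256, 512, 512] and in == out channels for every added resnet:

# Each added resnet has two 3x3 convs (weight + bias) and two GroupNorms (weight + bias).
def resnet_params(c):
    return 2 * (9 * c * c + c) + 2 * 2 * c

channels = [128, 128, 256, 512, 512]                # decoder mirrors this in reverse
added = 2 * sum(resnet_params(c) for c in channels)
print(added, "params =", added * 4, "bytes in fp32")  # ~89.7 MB, matching the size delta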
scale.py ADDED
@@ -0,0 +1,107 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from diffusers import AutoencoderKL
5
+ from tqdm import tqdm
6
+ import pathlib
7
+
8
+ # ── 1. Load the VAE ──────────────────────────────────────────────────────────
9
+ vae = AutoencoderKL.from_pretrained("vae32ch", torch_dtype=torch.float32)
10
+ vae.eval().cuda()
11
+
12
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)  # 2 per downsampling level (8 for a 4-block VAE)
13
+
14
+ # ── 2. Collect all PNGs recursively ──────────────────────────────────────────
15
+ dataset_path = pathlib.Path("/workspace/ds")
16
+ image_paths = sorted(dataset_path.rglob("*.png"))
17
+ print(f"Найдено картинок: {len(image_paths)}")
18
+
19
+ # Take the first 30000
20
+ image_paths = image_paths[:30000]
21
+
22
+ # ── 3. Preprocessing: crop to a multiple of the scale factor, no resize ──────
23
+ def preprocess(path):
24
+ img = Image.open(path).convert("RGB")
25
+ w, h = img.size
26
+
27
+ new_w = (w // vae_scale_factor) * vae_scale_factor
28
+ new_h = (h // vae_scale_factor) * vae_scale_factor
29
+
30
+ if new_w != w or new_h != h:
31
+ left = (w - new_w) // 2
32
+ top = (h - new_h) // 2
33
+ img = img.crop((left, top, left + new_w, top + new_h))
34
+
35
+ x = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
36
+ x = x.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
37
+ x = x * 2.0 - 1.0 # [-1, 1]
38
+ return x
39
+
40
+ # ── 4. Per-channel latent statistics ─────────────────────────────────────────
41
+ latent_channels = vae.config.latent_channels # 32
42
+
43
+ all_means = [] # [N, C]
44
+ all_stds = [] # [N, C]
45
+ errors = []
46
+
47
+ with torch.no_grad():
48
+ for path in tqdm(image_paths, desc="Encoding"):
49
+ try:
50
+ x = preprocess(path).cuda()
51
+ lat = vae.encode(x).latent_dist.sample() # [1, C, H, W]
52
+ flat = lat.squeeze(0).float().reshape(latent_channels, -1) # [C, H*W]
53
+
54
+ all_means.append(flat.mean(dim=1).cpu()) # [C]
55
+ all_stds.append(flat.std(dim=1).cpu()) # [C]
56
+
57
+ except Exception as e:
58
+ errors.append((path, str(e)))
59
+
60
+ if errors:
61
+ print(f"\nОшибки ({len(errors)}):")
62
+ for p, e in errors:
63
+ print(f" {p}: {e}")
64
+
65
+ mean = torch.stack(all_means).mean(dim=0) # [C]
66
+ std  = torch.stack(all_stds).mean(dim=0)   # [C]; averaging per-image stds approximates the pooled std
67
+
68
+ print(f"\nОбработано картинок: {len(all_means)}")
69
+ print(f"\nlatents_mean ({latent_channels} каналов):")
70
+ print(mean.tolist())
71
+ print(f"\nlatents_std ({latent_channels} каналов):")
72
+ print(std.tolist())
73
+
74
+ # ── 5. Build a new VAE with the same architecture + scaling vectors ──────────
75
+ cfg = vae.config
76
+
77
+ new_vae = AutoencoderKL(
78
+ in_channels = cfg.in_channels,
79
+ out_channels = cfg.out_channels,
80
+ latent_channels = cfg.latent_channels,
81
+ block_out_channels = cfg.block_out_channels,
82
+ layers_per_block = cfg.layers_per_block,
83
+ norm_num_groups = cfg.norm_num_groups,
84
+ act_fn = cfg.act_fn,
85
+ down_block_types = cfg.down_block_types,
86
+ up_block_types = cfg.up_block_types,
87
+ )
88
+ new_vae.eval()
89
+
90
+ # Transfer the weights
91
+ result = new_vae.load_state_dict(vae.state_dict(), strict=False)
92
+ print(f"\nВеса перенесены: {result}")
93
+
94
+ # Write the scaling vectors into the config
95
+ new_vae.register_to_config(
96
+ latents_mean = mean.tolist(),
97
+ latents_std = std.tolist(),
98
+ scaling_factor = 1.0,
99
+ shift_factor = 0.0,
100
+ )
101
+
102
+ print(f"\nlatents_mean в конфиге: {new_vae.config.latents_mean[:4]}...")
103
+ print(f"latents_std в конфиге: {new_vae.config.latents_std[:4]}...")
104
+
105
+ # ── 6. Save ──────────────────────────────────────────────────────────────────
106
+ new_vae.save_pretrained("vae32ch2")
107
+ print("\nСохранено в vae32ch2/")
train_vae_16x.py CHANGED
@@ -29,7 +29,7 @@ from collections import deque
29
 
30
  # --------------------------- Parameters ---------------------------
31
  ds_path = "/workspace/d23"
32
- project = "vae16x32ch_new"
33
  batch_size = 1
34
  base_learning_rate = 6e-6
35
  min_learning_rate = 7e-7
@@ -41,8 +41,8 @@ use_decay = True
41
  optimizer_type = "adam8bit"
42
  dtype = torch.float32
43
 
44
- model_resolution = 768 #448 #288
45
- high_resolution = 768 #896 #576
46
  limit = 0
47
  save_barrier = 1.3
48
  warmup_percent = 0.005
@@ -53,7 +53,7 @@ clip_grad_norm = 1.0
53
  mixed_precision = "no"
54
  gradient_accumulation_steps = 1
55
  generated_folder = "samples"
56
- save_as = "vae16x32ch_new"
57
  num_workers = 0
58
  device = None
59
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -95,7 +95,7 @@ accelerator = Accelerator(
95
  device = accelerator.device
96
 
97
  # reproducibility
98
- seed = int(datetime.now().strftime("%Y%m%d")) + 13
99
  torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
100
  torch.backends.cudnn.benchmark = False
101
 
 
29
 
30
  # --------------------------- Parameters ---------------------------
31
  ds_path = "/workspace/d23"
32
+ project = "vae16x32ch"
33
  batch_size = 1
34
  base_learning_rate = 6e-6
35
  min_learning_rate = 7e-7
 
41
  optimizer_type = "adam8bit"
42
  dtype = torch.float32
43
 
44
+ model_resolution = 640 #448 #288
45
+ high_resolution = 640 #896 #576
46
  limit = 0
47
  save_barrier = 1.3
48
  warmup_percent = 0.005
 
53
  mixed_precision = "no"
54
  gradient_accumulation_steps = 1
55
  generated_folder = "samples"
56
+ save_as = "vae16x32ch"
57
  num_workers = 0
58
  device = None
59
  torch.backends.cuda.matmul.allow_tf32 = True
 
95
  device = accelerator.device
96
 
97
  # reproducibility
98
+ seed = int(datetime.now().strftime("%Y%m%d")) + 42
99
  torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
100
  torch.backends.cudnn.benchmark = False
101
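A constraint worth noting behind the model_resolution edit (768 → 640): any training resolution must be a multiple of the VAE's spatial downscale, which for this five-block encoder is 2**(5-1) = 16. Both values satisfy it, so the change is presumably about memory and throughput rather than divisibility. A quick check:

scale = 2 ** (5 - 1)  # five entries in block_out_channels -> 16x downscale
for res in (640, 768):
    assert res % scale == 0, f"{res} is not divisible by {scale}"
    print(res, "->", res // scale, "latent pixels per side")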