File size: 21,209 Bytes
77a5b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c15deb04-94a0-4073-a174-adcd22af10b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Создана новая модель: <class 'diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL'>\n",
      "\n",
      "--- Перенос весов ---\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 326/326 [00:00<00:00, 54186.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "✅ Перенос завершён.\n",
      "Статистика:\n",
      "  перенесено: 227\n",
      "  дублировано: 0\n",
      "  пропущено: 0\n",
      "AutoencoderKL(\n",
      "  (encoder): Encoder(\n",
      "    (conv_in): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (down_blocks): ModuleList(\n",
      "      (0): DownEncoderBlock2D(\n",
      "        (resnets): ModuleList(\n",
      "          (0-2): 3 x ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "          )\n",
      "        )\n",
      "        (downsamplers): ModuleList(\n",
      "          (0): Downsample2D(\n",
      "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (1): DownEncoderBlock2D(\n",
      "        (resnets): ModuleList(\n",
      "          (0): ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "            (conv_shortcut): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n",
      "          )\n",
      "          (1-2): 2 x ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "          )\n",
      "        )\n",
      "        (downsamplers): ModuleList(\n",
      "          (0): Downsample2D(\n",
      "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (2): DownEncoderBlock2D(\n",
      "        (resnets): ModuleList(\n",
      "          (0): ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "            (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          )\n",
      "          (1-2): 2 x ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "          )\n",
      "        )\n",
      "        (downsamplers): ModuleList(\n",
      "          (0): Downsample2D(\n",
      "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (3): DownEncoderBlock2D(\n",
      "        (resnets): ModuleList(\n",
      "          (0): ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "            (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
      "          )\n",
      "          (1-2): 2 x ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "          )\n",
      "        )\n",
      "        (downsamplers): ModuleList(\n",
      "          (0): Downsample2D(\n",
      "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (4): DownEncoderBlock2D(\n",
      "        (resnets): ModuleList(\n",
      "          (0-2): 3 x ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (mid_block): UNetMidBlock2D(\n",
      "      (attentions): ModuleList(\n",
      "        (0): Attention(\n",
      "          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "          (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
      "          (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
      "          (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
      "          (to_out): ModuleList(\n",
      "            (0): Linear(in_features=512, out_features=512, bias=True)\n",
      "            (1): Dropout(p=0.0, inplace=False)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (resnets): ModuleList(\n",
      "        (0-1): 2 x ResnetBlock2D(\n",
      "          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "          (nonlinearity): SiLU()\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "    (conv_act): SiLU()\n",
      "    (conv_out): Conv2d(512, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "  )\n",
      "  (decoder): Decoder(\n",
      "    (conv_in): Conv2d(16, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (up_blocks): ModuleList(\n",
      "      (0-1): 2 x UpDecoderBlock2D(\n",
      "        (resnets): ModuleList(\n",
      "          (0-3): 4 x ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "          )\n",
      "        )\n",
      "        (upsamplers): ModuleList(\n",
      "          (0): Upsample2D(\n",
      "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (2): UpDecoderBlock2D(\n",
      "        (resnets): ModuleList(\n",
      "          (0): ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "            (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
      "          )\n",
      "          (1-3): 3 x ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "          )\n",
      "        )\n",
      "        (upsamplers): ModuleList(\n",
      "          (0): Upsample2D(\n",
      "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (3): UpDecoderBlock2D(\n",
      "        (resnets): ModuleList(\n",
      "          (0): ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "            (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
      "          )\n",
      "          (1-3): 3 x ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "          )\n",
      "        )\n",
      "        (upsamplers): ModuleList(\n",
      "          (0): Upsample2D(\n",
      "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (4): UpDecoderBlock2D(\n",
      "        (resnets): ModuleList(\n",
      "          (0): ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "            (conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))\n",
      "          )\n",
      "          (1-3): 3 x ResnetBlock2D(\n",
      "            (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)\n",
      "            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)\n",
      "            (dropout): Dropout(p=0.0, inplace=False)\n",
      "            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "            (nonlinearity): SiLU()\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (mid_block): UNetMidBlock2D(\n",
      "      (attentions): ModuleList(\n",
      "        (0): Attention(\n",
      "          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "          (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
      "          (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
      "          (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
      "          (to_out): ModuleList(\n",
      "            (0): Linear(in_features=512, out_features=512, bias=True)\n",
      "            (1): Dropout(p=0.0, inplace=False)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (resnets): ModuleList(\n",
      "        (0-1): 2 x ResnetBlock2D(\n",
      "          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
      "          (dropout): Dropout(p=0.0, inplace=False)\n",
      "          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "          (nonlinearity): SiLU()\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (conv_norm_out): GroupNorm(32, 64, eps=1e-06, affine=True)\n",
      "    (conv_act): SiLU()\n",
      "    (conv_out): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# ---- Конфиг новой модели ----\n",
    "config = {\n",
    "    \"_class_name\": \"AutoencoderKL\",\n",
    "    \"act_fn\": \"silu\",\n",
    "    \"in_channels\": 3,\n",
    "    \"out_channels\": 3,\n",
    "    \"scaling_factor\": 1.0,\n",
    "    \"norm_num_groups\": 32,\n",
    "    \"block_out_channels\": [64, 128, 256, 512, 512],\n",
    "    \"down_block_types\": [\n",
    "        \"DownEncoderBlock2D\",\n",
    "        \"DownEncoderBlock2D\",\n",
    "        \"DownEncoderBlock2D\",\n",
    "        \"DownEncoderBlock2D\",\n",
    "        \"DownEncoderBlock2D\",\n",
    "    ],\n",
    "    \"latent_channels\": 16,\n",
    "    \"up_block_types\": [\n",
    "        \"UpDecoderBlock2D\",\n",
    "        \"UpDecoderBlock2D\",\n",
    "        \"UpDecoderBlock2D\",\n",
    "        \"UpDecoderBlock2D\",\n",
    "        \"UpDecoderBlock2D\",\n",
    "    ],\n",
    "}\n",
    "\n",
    "# ---- Создание пустой асимметричной модели ----\n",
    "vae = AutoencoderKL(\n",
    "    act_fn=config[\"act_fn\"],\n",
    "    block_out_channels=config[\"block_out_channels\"],\n",
    "    down_block_types=config[\"down_block_types\"],\n",
    "    latent_channels=config[\"latent_channels\"],\n",
    "    up_block_types=config[\"up_block_types\"],\n",
    "    in_channels=config[\"in_channels\"],\n",
    "    out_channels=config[\"out_channels\"],\n",
    "    scaling_factor=config[\"scaling_factor\"],\n",
    "    norm_num_groups=config[\"norm_num_groups\"],\n",
    "    layers_per_block=3,\n",
    "    sample_size=1024,\n",
    "    use_post_quant_conv = False,\n",
    "    use_quant_conv = False,\n",
    ")\n",
    "\n",
    "vae.save_pretrained(\"vae_empty\")\n",
    "print(\"✅ Создана новая модель:\", type(vae))\n",
    "\n",
    "# ---- Функция переноса весов старого VAE ----\n",
    "def transfer_weights(old_path, new_path, save_path=\"asymmetric_vae\", device=\"cuda\", dtype=torch.float16):\n",
    "    old_vae = AsymmetricAutoencoderKL.from_pretrained(old_path).to(device, dtype=dtype)\n",
    "    new_vae = AutoencoderKL.from_pretrained(new_path).to(device, dtype=dtype)\n",
    "\n",
    "    old_sd = old_vae.state_dict()\n",
    "    new_sd = new_vae.state_dict()\n",
    "\n",
    "    transferred_keys = set()\n",
    "    transfer_stats = {\"перенесено\": 0, \"дублировано\": 0, \"пропущено\": 0}\n",
    "\n",
    "    print(\"\\n--- Перенос весов ---\")\n",
    "    for k, v in tqdm(old_sd.items()):\n",
    "        # Копирование энкодера и прочих совпадающих ключей\n",
    "        if (\"encoder\" in k) or (\"quant_conv\" in k) or (\"post_quant_conv\" in k):\n",
    "            if k in new_sd and new_sd[k].shape == v.shape:\n",
    "                new_sd[k] = v.clone()\n",
    "                transferred_keys.add(k)\n",
    "                transfer_stats[\"перенесено\"] += 1\n",
    "            continue\n",
    "\n",
    "        # Копирование декодера (без сдвига)\n",
    "        if \"decoder.up_blocks\" in k:\n",
    "            if k in new_sd and new_sd[k].shape == v.shape:\n",
    "                new_sd[k] = v.clone()\n",
    "                transferred_keys.add(k)\n",
    "                transfer_stats[\"перенесено\"] += 1\n",
    "            continue\n",
    "\n",
    "    # Дублирование весов старого первого 512→512 блока в новый блок 64→128 для апскейла\n",
    "    #ref_prefix = \"encoder.down_blocks.1\"\n",
    "    #new_prefix = \"encoder.down_blocks.0\"\n",
    "    #for k, v in old_sd.items():\n",
    "    #    if k.startswith(ref_prefix) and new_prefix + k[len(ref_prefix):] in new_sd:\n",
    "    #        new_k = k.replace(ref_prefix, new_prefix)\n",
    "    #        if new_sd[new_k].shape == v.shape:\n",
    "    #            new_sd[new_k] = v.clone()\n",
    "    #            transferred_keys.add(new_k)\n",
    "    #            transfer_stats[\"дублировано\"] += 1\n",
    "\n",
    "    # Загрузка и сохранение\n",
    "    new_vae.load_state_dict(new_sd, strict=False)\n",
    "    new_vae.save_pretrained(save_path)\n",
    "\n",
    "    print(\"\\n✅ Перенос завершён.\")\n",
    "    print(\"Статистика:\")\n",
    "    for k, v in transfer_stats.items():\n",
    "        print(f\"  {k}: {v}\")\n",
    "    print(new_vae)\n",
    "\n",
    "# ---- Запуск переноса ----\n",
    "transfer_weights(\"vae16\", \"vae_empty\", save_path=\"vae17\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59fcafb9-6d89-49b4-8362-b4891f591687",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}