recoilme committed
Commit 6c0c34c · verified · 1 Parent(s): f84b91f

Upload folder using huggingface_hub
.ipynb_checkpoints/transfer_simplevae3-checkpoint.ipynb ADDED
@@ -0,0 +1,415 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "c15deb04-94a0-4073-a174-adcd22af10b8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✅ New model created: <class 'diffusers.models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL'>\n",
+ "\n",
+ "--- Weight transfer ---\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 324/324 [00:00<00:00, 56241.13it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "✅ Transfer complete.\n",
+ "Statistics:\n",
+ " transferred: 251\n",
+ " duplicated: 2\n",
+ " skipped: 0\n",
+ "AsymmetricAutoencoderKL(\n",
+ "  (encoder): Encoder(\n",
+ "    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "    (down_blocks): ModuleList(\n",
+ "      (0-1): 2 x DownEncoderBlock2D(\n",
+ "        (resnets): ModuleList(\n",
+ "          (0-1): 2 x ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "          )\n",
+ "        )\n",
+ "        (downsamplers): ModuleList(\n",
+ "          (0): Downsample2D(\n",
+ "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
+ "          )\n",
+ "        )\n",
+ "      )\n",
+ "      (2): DownEncoderBlock2D(\n",
+ "        (resnets): ModuleList(\n",
+ "          (0): ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "            (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
+ "          )\n",
+ "          (1): ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "          )\n",
+ "        )\n",
+ "        (downsamplers): ModuleList(\n",
+ "          (0): Downsample2D(\n",
+ "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
+ "          )\n",
+ "        )\n",
+ "      )\n",
+ "      (3): DownEncoderBlock2D(\n",
+ "        (resnets): ModuleList(\n",
+ "          (0): ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "            (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
+ "          )\n",
+ "          (1): ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "          )\n",
+ "        )\n",
+ "        (downsamplers): ModuleList(\n",
+ "          (0): Downsample2D(\n",
+ "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n",
+ "          )\n",
+ "        )\n",
+ "      )\n",
+ "    )\n",
+ "    (mid_block): UNetMidBlock2D(\n",
+ "      (attentions): ModuleList(\n",
+ "        (0): Attention(\n",
+ "          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "          (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
+ "          (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
+ "          (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
+ "          (to_out): ModuleList(\n",
+ "            (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ "            (1): Dropout(p=0.0, inplace=False)\n",
+ "          )\n",
+ "        )\n",
+ "      )\n",
+ "      (resnets): ModuleList(\n",
+ "        (0-1): 2 x ResnetBlock2D(\n",
+ "          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "          (dropout): Dropout(p=0.0, inplace=False)\n",
+ "          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "          (nonlinearity): SiLU()\n",
+ "        )\n",
+ "      )\n",
+ "    )\n",
+ "    (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "    (conv_act): SiLU()\n",
+ "    (conv_out): Conv2d(512, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "  )\n",
+ "  (decoder): MaskConditionDecoder(\n",
+ "    (conv_in): Conv2d(16, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "    (up_blocks): ModuleList(\n",
+ "      (0-1): 2 x UpDecoderBlock2D(\n",
+ "        (resnets): ModuleList(\n",
+ "          (0-3): 4 x ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "          )\n",
+ "        )\n",
+ "        (upsamplers): ModuleList(\n",
+ "          (0): Upsample2D(\n",
+ "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "          )\n",
+ "        )\n",
+ "      )\n",
+ "      (2): UpDecoderBlock2D(\n",
+ "        (resnets): ModuleList(\n",
+ "          (0): ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "            (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
+ "          )\n",
+ "          (1-3): 3 x ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "          )\n",
+ "        )\n",
+ "        (upsamplers): ModuleList(\n",
+ "          (0): Upsample2D(\n",
+ "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "          )\n",
+ "        )\n",
+ "      )\n",
+ "      (3): UpDecoderBlock2D(\n",
+ "        (resnets): ModuleList(\n",
+ "          (0): ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "            (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
+ "          )\n",
+ "          (1-3): 3 x ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "          )\n",
+ "        )\n",
+ "        (upsamplers): ModuleList(\n",
+ "          (0): Upsample2D(\n",
+ "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "          )\n",
+ "        )\n",
+ "      )\n",
+ "      (4): UpDecoderBlock2D(\n",
+ "        (resnets): ModuleList(\n",
+ "          (0-3): 4 x ResnetBlock2D(\n",
+ "            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+ "            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+ "            (dropout): Dropout(p=0.0, inplace=False)\n",
+ "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "            (nonlinearity): SiLU()\n",
+ "          )\n",
+ "        )\n",
+ "      )\n",
+ "    )\n",
+ "    (mid_block): UNetMidBlock2D(\n",
+ "      (attentions): ModuleList(\n",
+ "        (0): Attention(\n",
+ "          (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "          (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
+ "          (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
+ "          (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
+ "          (to_out): ModuleList(\n",
+ "            (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ "            (1): Dropout(p=0.0, inplace=False)\n",
+ "          )\n",
+ "        )\n",
+ "      )\n",
+ "      (resnets): ModuleList(\n",
+ "        (0-1): 2 x ResnetBlock2D(\n",
+ "          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ "          (dropout): Dropout(p=0.0, inplace=False)\n",
+ "          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "          (nonlinearity): SiLU()\n",
+ "        )\n",
+ "      )\n",
+ "    )\n",
+ "    (condition_encoder): MaskConditionEncoder(\n",
+ "      (layers): Sequential(\n",
+ "        (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "        (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "        (2): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ "        (3): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ "        (4): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
+ "      )\n",
+ "    )\n",
+ "    (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
+ "    (conv_act): SiLU()\n",
+ "    (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ "  )\n",
+ "  (quant_conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
+ "  (post_quant_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))\n",
+ ")\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL\n",
+ "import torch\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "# ---- Config for the new model ----\n",
+ "config = {\n",
+ "    \"_class_name\": \"AsymmetricAutoencoderKL\",\n",
+ "    \"act_fn\": \"silu\",\n",
+ "    \"in_channels\": 3,\n",
+ "    \"out_channels\": 3,\n",
+ "    \"scaling_factor\": 1.0,\n",
+ "    \"norm_num_groups\": 32,\n",
+ "    \"down_block_out_channels\": [128, 128, 256, 512, 512],\n",
+ "    \"down_block_types\": [\n",
+ "        \"DownEncoderBlock2D\",\n",
+ "        \"DownEncoderBlock2D\",\n",
+ "        \"DownEncoderBlock2D\",\n",
+ "        \"DownEncoderBlock2D\",\n",
+ "    ],\n",
+ "    \"latent_channels\": 16,\n",
+ "    \"up_block_out_channels\": [128, 128, 256, 512, 512],\n",
+ "    \"up_block_types\": [\n",
+ "        \"UpDecoderBlock2D\",\n",
+ "        \"UpDecoderBlock2D\",\n",
+ "        \"UpDecoderBlock2D\",\n",
+ "        \"UpDecoderBlock2D\",\n",
+ "        \"UpDecoderBlock2D\",\n",
+ "    ],\n",
+ "}\n",
+ "\n",
+ "# ---- Create an empty asymmetric model ----\n",
+ "vae = AsymmetricAutoencoderKL(\n",
+ "    act_fn=config[\"act_fn\"],\n",
+ "    down_block_out_channels=config[\"down_block_out_channels\"],\n",
+ "    down_block_types=config[\"down_block_types\"],\n",
+ "    latent_channels=config[\"latent_channels\"],\n",
+ "    up_block_out_channels=config[\"up_block_out_channels\"],\n",
+ "    up_block_types=config[\"up_block_types\"],\n",
+ "    in_channels=config[\"in_channels\"],\n",
+ "    out_channels=config[\"out_channels\"],\n",
+ "    scaling_factor=config[\"scaling_factor\"],\n",
+ "    norm_num_groups=config[\"norm_num_groups\"],\n",
+ "    layers_per_down_block=2,\n",
+ "    layers_per_up_block=3,\n",
+ "    sample_size=1024\n",
+ ")\n",
+ "\n",
+ "vae.save_pretrained(\"asymmetric_vae_empty\")\n",
+ "print(\"✅ New model created:\", type(vae))\n",
+ "\n",
+ "# ---- Weight-transfer function for the old VAE ----\n",
+ "def transfer_weights(old_path, new_path, save_path=\"asymmetric_vae\", device=\"cuda\", dtype=torch.float16):\n",
+ "    old_vae = AsymmetricAutoencoderKL.from_pretrained(old_path).to(device, dtype=dtype)\n",
+ "    new_vae = AsymmetricAutoencoderKL.from_pretrained(new_path).to(device, dtype=dtype)\n",
+ "\n",
+ "    old_sd = old_vae.state_dict()\n",
+ "    new_sd = new_vae.state_dict()\n",
+ "\n",
+ "    transferred_keys = set()\n",
+ "    transfer_stats = {\"transferred\": 0, \"duplicated\": 0, \"skipped\": 0}\n",
+ "\n",
+ "    print(\"\\n--- Weight transfer ---\")\n",
+ "    for k, v in tqdm(old_sd.items()):\n",
+ "        # Copy the encoder and other matching keys\n",
+ "        if (\"encoder\" in k) or (\"quant_conv\" in k) or (\"post_quant_conv\" in k):\n",
+ "            if k in new_sd and new_sd[k].shape == v.shape:\n",
+ "                new_sd[k] = v.clone()\n",
+ "                transferred_keys.add(k)\n",
+ "                transfer_stats[\"transferred\"] += 1\n",
+ "            continue\n",
+ "\n",
+ "        # Copy the decoder (no index shift)\n",
+ "        if \"decoder.up_blocks\" in k:\n",
+ "            if k in new_sd and new_sd[k].shape == v.shape:\n",
+ "                new_sd[k] = v.clone()\n",
+ "                transferred_keys.add(k)\n",
+ "                transfer_stats[\"transferred\"] += 1\n",
+ "            continue\n",
+ "\n",
+ "    # Duplicate the old first 512→512 block's weights into the new 64→128 block for the upscale\n",
+ "    ref_prefix = \"encoder.down_blocks.1\"\n",
+ "    new_prefix = \"encoder.down_blocks.0\"\n",
+ "    for k, v in old_sd.items():\n",
+ "        if k.startswith(ref_prefix) and new_prefix + k[len(ref_prefix):] in new_sd:\n",
+ "            new_k = k.replace(ref_prefix, new_prefix)\n",
+ "            if new_sd[new_k].shape == v.shape:\n",
+ "                new_sd[new_k] = v.clone()\n",
+ "                transferred_keys.add(new_k)\n",
+ "                transfer_stats[\"duplicated\"] += 1\n",
+ "\n",
+ "    # Load and save\n",
+ "    new_vae.load_state_dict(new_sd, strict=False)\n",
+ "    new_vae.save_pretrained(save_path)\n",
+ "\n",
+ "    print(\"\\n✅ Transfer complete.\")\n",
+ "    print(\"Statistics:\")\n",
+ "    for k, v in transfer_stats.items():\n",
+ "        print(f\" {k}: {v}\")\n",
+ "    print(new_vae)\n",
+ "\n",
+ "# ---- Run the transfer ----\n",
+ "transfer_weights(\"vae10\", \"asymmetric_vae_empty\", save_path=\"vae11\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "59fcafb9-6d89-49b4-8362-b4891f591687",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
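
A quick way to sanity-check the transfer above is to diff the state dicts of the source and destination checkpoints. A minimal sketch, assuming vae10 and the freshly written vae11 sit in the working directory (the snippet is illustrative, not part of the commit):

import torch
from diffusers.models import AsymmetricAutoencoderKL

# Count tensors that were carried over from vae10 into vae11 verbatim.
old_sd = AsymmetricAutoencoderKL.from_pretrained("vae10").state_dict()
new_sd = AsymmetricAutoencoderKL.from_pretrained("vae11").state_dict()
matched = sum(
    1 for k, v in old_sd.items()
    if k in new_sd and new_sd[k].shape == v.shape and torch.equal(new_sd[k], v)
)
print(f"{matched} of {len(old_sd)} source tensors appear in the new model unchanged")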
.ipynb_checkpoints/vae_comp-checkpoint.ipynb ADDED
@@ -0,0 +1,214 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "b3b23a40-8354-4287-bac2-32f9d084fff3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py:202: UserWarning: The `local_dir_use_symlinks` argument is deprecated and ignored in `hf_hub_download`. Downloading to a local directory does not use symlinks anymore.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "96d38ff0fa134b02a5a21c96bdfd36b5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "vae/config.json: 0%| | 0.00/752 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0a44c60705d44f58b5a07ead45936327",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "vae/diffusion_pytorch_model.safetensors: 0%| | 0.00/191M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "sdxs_vae log-variance: 1.840\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "vae9 log-variance: 1.840\n",
+ "Done\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from PIL import Image\n",
+ "from diffusers import AutoencoderKL, AsymmetricAutoencoderKL\n",
+ "from torchvision.transforms.functional import to_pil_image\n",
+ "import matplotlib.pyplot as plt\n",
+ "import os\n",
+ "from torchvision.transforms import ToTensor, Normalize, CenterCrop\n",
+ "\n",
+ "# path to your image\n",
+ "IMG_PATH = \"1234567890.png\"\n",
+ "OUT_DIR = \"vaetest\"\n",
+ "device = \"cuda\"\n",
+ "dtype = torch.float32  # ← uniform float32\n",
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
+ "\n",
+ "# list of VAEs\n",
+ "VAES = {\n",
+ "    #\"sdxl\": \"madebyollin/sdxl-vae-fp16-fix\",\n",
+ "    \"sdxs_vae\": \"AiArtLab/sdxs-1b\",\n",
+ "    #\"vae8\": \"/workspace/simplevae2x/vae8\",\n",
+ "    \"vae9\": \"/workspace/simplevae2x/vae9\"\n",
+ "}\n",
+ "\n",
+ "def load_image(path):\n",
+ "    img = Image.open(path).convert('RGB')\n",
+ "    # crop to a multiple of 8\n",
+ "    w, h = img.size\n",
+ "    img = CenterCrop((h // 8 * 8, w // 8 * 8))(img)\n",
+ "    tensor = ToTensor()(img).unsqueeze(0)  # [0,1]\n",
+ "    tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(tensor)  # [-1,1]\n",
+ "    return img, tensor.to(device, dtype=dtype)\n",
+ "\n",
+ "# back to PIL\n",
+ "def tensor_to_img(t):\n",
+ "    t = (t * 0.5 + 0.5).clamp(0, 1)\n",
+ "    return to_pil_image(t[0])\n",
+ "\n",
+ "def logvariance(latents):\n",
+ "    \"\"\"Return the log-variance over all elements.\"\"\"\n",
+ "    return torch.log(latents.var() + 1e-8).item()\n",
+ "\n",
+ "def plot_latent_distribution(latents, title, save_path):\n",
+ "    \"\"\"Histogram + QQ-plot.\"\"\"\n",
+ "    lat = latents.detach().cpu().numpy().flatten()\n",
+ "    plt.figure(figsize=(10, 4))\n",
+ "\n",
+ "    # histogram\n",
+ "    plt.subplot(1, 2, 1)\n",
+ "    plt.hist(lat, bins=100, density=True, alpha=0.7, color='steelblue')\n",
+ "    plt.title(f\"{title} histogram\")\n",
+ "    plt.xlabel(\"latent value\")\n",
+ "    plt.ylabel(\"density\")\n",
+ "\n",
+ "    # QQ-plot\n",
+ "    from scipy.stats import probplot\n",
+ "    plt.subplot(1, 2, 2)\n",
+ "    probplot(lat, dist=\"norm\", plot=plt)\n",
+ "    plt.title(f\"{title} QQ-plot\")\n",
+ "\n",
+ "    plt.tight_layout()\n",
+ "    plt.savefig(save_path)\n",
+ "    plt.close()\n",
+ "\n",
+ "for name, repo in VAES.items():\n",
+ "    if name == \"sdxs_vae\":\n",
+ "        vae = AsymmetricAutoencoderKL.from_pretrained(repo, subfolder=\"vae\", torch_dtype=dtype).to(device)\n",
+ "    else:\n",
+ "        vae = AsymmetricAutoencoderKL.from_pretrained(repo, torch_dtype=dtype).to(device)  # , subfolder=\"vae\", variant=\"fp16\"\n",
+ "\n",
+ "    cfg = vae.config\n",
+ "    scale = getattr(cfg, \"scaling_factor\", 1.)\n",
+ "    shift = getattr(cfg, \"shift_factor\", 0.0)\n",
+ "    mean = getattr(cfg, \"latents_mean\", None)\n",
+ "    std = getattr(cfg, \"latents_std\", None)\n",
+ "\n",
+ "    C = 4  # 4 for SDXL\n",
+ "    if mean is not None:\n",
+ "        mean = torch.tensor(mean, device=device, dtype=dtype).view(1, C, 1, 1)\n",
+ "    if std is not None:\n",
+ "        std = torch.tensor(std, device=device, dtype=dtype).view(1, C, 1, 1)\n",
+ "    if shift is not None:\n",
+ "        shift = torch.tensor(shift, device=device, dtype=dtype)\n",
+ "    else:\n",
+ "        shift = 0.0\n",
+ "\n",
+ "    scale = torch.tensor(scale, device=device, dtype=dtype)\n",
+ "\n",
+ "    img, x = load_image(IMG_PATH)\n",
+ "    img.save(os.path.join(OUT_DIR, \"original.png\"))\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        # encode\n",
+ "        latents = vae.encode(x).latent_dist.sample().to(dtype)\n",
+ "        if mean is not None and std is not None:\n",
+ "            latents = (latents - mean) / std\n",
+ "        latents = latents * scale + shift\n",
+ "\n",
+ "        lv = logvariance(latents)\n",
+ "        print(f\"{name} log-variance: {lv:.3f}\")\n",
+ "\n",
+ "        # plot\n",
+ "        plot_latent_distribution(latents, f\"{name}_latents\",\n",
+ "                                 os.path.join(OUT_DIR, f\"dist_{name}.png\"))\n",
+ "\n",
+ "        # decode\n",
+ "        latents = (latents - shift) / scale\n",
+ "        if mean is not None and std is not None:\n",
+ "            latents = latents * std + mean\n",
+ "        rec = vae.decode(latents).sample\n",
+ "\n",
+ "    tensor_to_img(rec).save(os.path.join(OUT_DIR, f\"decoded_{name}.png\"))\n",
+ "\n",
+ "print(\"Done\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "200b72ab-1978-4d71-9aba-b1ef97cf0b27",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
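
A side note on the printed numbers: since logvariance returns log(var), the scaling factor that would whiten these latents to unit variance follows directly from it. A worked one-liner under that definition, using the value from the output above:

import math

lv = 1.840              # log-variance reported above
std = math.exp(lv / 2)  # ≈ 2.51
print(1 / std)          # ≈ 0.398 — a scaling_factor that would give unit-variance latents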
config.json ADDED
@@ -0,0 +1,49 @@
+ {
+   "_class_name": "AsymmetricAutoencoderKL",
+   "_diffusers_version": "0.36.0",
+   "_name_or_path": "vae16",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_out_channels": [
+     128,
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": false,
+   "in_channels": 3,
+   "latent_channels": 16,
+   "layers_per_down_block": 2,
+   "layers_per_up_block": 3,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 1024,
+   "scaling_factor": 1.0,
+   "up_block_out_channels": [
+     128,
+     128,
+     256,
+     512,
+     512
+   ],
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
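
A minimal loading sketch for the config above, assuming this config.json and the safetensors weights below have been fetched into a local folder (the "./" path is illustrative):

import torch
from diffusers import AsymmetricAutoencoderKL

vae = AsymmetricAutoencoderKL.from_pretrained("./", torch_dtype=torch.float32).eval()

with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)           # in_channels=3
    latents = vae.encode(x).latent_dist.mean  # latent_channels=16
    print(latents.shape)                      # spatial size is set by the four down blocks
    rec = vae.decode(latents).sample          # reconstruction at the input resolution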
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c222daaf0f6f70a5c762254739cc9dc04494a890eddc110ec2c8431667515593
+ size 392649428
train_vae_fdl.py ADDED
@@ -0,0 +1,625 @@
+ # -*- coding: utf-8 -*-
+ import os
+ import math
+ import re
+ import torch
+ import numpy as np
+ import random
+ import gc
+ from datetime import datetime
+ from pathlib import Path
+
+ import torchvision.transforms as transforms
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader, Dataset
+ from torch.optim.lr_scheduler import LambdaLR
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
+ # QWEN: class imports
+ from diffusers import AutoencoderKLQwenImage
+ from diffusers import AutoencoderKLWan
+
+ from accelerate import Accelerator
+ from PIL import Image, UnidentifiedImageError
+ from tqdm import tqdm
+ import bitsandbytes as bnb
+ import wandb
+ import lpips  # pip install lpips
+ from FDL_pytorch import FDL_loss  # pip install fdl-pytorch
+ from collections import deque
+
+ # --------------------------- Parameters ---------------------------
+ ds_path = "/workspace/d23"
+ project = "vae16"
+ batch_size = 3
+ base_learning_rate = 6e-6
+ min_learning_rate = 7e-7
+ num_epochs = 1
+ sample_interval_share = 10
+ use_wandb = True
+ save_model = True
+ use_decay = True
+ optimizer_type = "adam8bit"
+ dtype = torch.float32
+
+ model_resolution = 512  # 288
+ high_resolution = 512  # 576
+ limit = 0
+ save_barrier = 1.3
+ warmup_percent = 0.005
+ percentile_clipping = 99
+ beta2 = 0.997
+ eps = 1e-8
+ clip_grad_norm = 1.0
+ mixed_precision = "no"
+ gradient_accumulation_steps = 1
+ generated_folder = "samples"
+ save_as = "vae16"
+ num_workers = 0
+ device = None
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ # Enable Flash Attention 2/SDPA  # MAX_JOBS=4 pip install flash-attn --no-build-isolation
+ torch.backends.cuda.enable_flash_sdp(True)
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
+ torch.backends.cuda.enable_math_sdp(False)
+
+ # --- Training modes ---
+ # QWEN: train only the decoder
+ train_encoder_only = False
+ train_up_only = False
+ train_down_only = False
+ full_training = True  # if True — train the whole VAE and add KL (below)
+ kl_ratio = 0.001
+
+ # Loss shares
+ loss_ratios = {
+     "lpips": 0.70,  # 0.50,
+     "fdl": 0.10,  # 0.25,
+     "edge": 0.05,
+     "mse": 0.10,
+     "mae": 0.049,
+     "kl": 0.001,  # enabled when full_training=True
+ }
+ median_coeff_steps = 250
+
+ resize_long_side = 1280  # resize the long side of the source images
+
+ # QWEN: model-loading config
+ vae_kind = "kl"  # "qwen" or "kl" (regular)
+
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
+
+ accelerator = Accelerator(
+     mixed_precision=mixed_precision,
+     gradient_accumulation_steps=gradient_accumulation_steps
+ )
+ device = accelerator.device
+
+ # reproducibility
+ seed = int(datetime.now().strftime("%Y%m%d")) + 13
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
+ torch.backends.cudnn.benchmark = False
+
+ # --------------------------- WandB ---------------------------
+ if use_wandb and accelerator.is_main_process:
+     wandb.init(project=project, config={
+         "batch_size": batch_size,
+         "base_learning_rate": base_learning_rate,
+         "num_epochs": num_epochs,
+         "optimizer_type": optimizer_type,
+         "model_resolution": model_resolution,
+         "high_resolution": high_resolution,
+         "gradient_accumulation_steps": gradient_accumulation_steps,
+         "train_encoder_only": train_encoder_only,
+         "full_training": full_training,
+         "kl_ratio": kl_ratio,
+         "vae_kind": vae_kind,
+     })
+
+ # --------------------------- VAE ---------------------------
+ def get_core_model(model):
+     m = model
+     # if the model is already wrapped by torch.compile
+     if hasattr(m, "_orig_mod"):
+         m = m._orig_mod
+     return m
+
+ def is_video_vae(model) -> bool:
+     # WAN/Qwen are video VAEs
+     if vae_kind in ("wan", "qwen"):
+         return True
+     # structural fallback (if ever needed)
+     try:
+         core = get_core_model(model)
+         enc = getattr(core, "encoder", None)
+         conv_in = getattr(enc, "conv_in", None)
+         w = getattr(conv_in, "weight", None)
+         if isinstance(w, torch.nn.Parameter):
+             return w.ndim == 5
+     except Exception:
+         pass
+     return False
+
+ # loading
+ if vae_kind == "qwen":
+     vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
+ else:
+     if vae_kind == "wan":
+         vae = AutoencoderKLWan.from_pretrained(project)
+     else:
+         # old behaviour (example)
+         if model_resolution == high_resolution:
+             vae = AsymmetricAutoencoderKL.from_pretrained(project)
+         else:
+             vae = AsymmetricAutoencoderKL.from_pretrained(project)
+
+ vae = vae.to(dtype)
+
+ # torch.compile (optional)
+ if hasattr(torch, "compile"):
+     try:
+         vae = torch.compile(vae)
+     except Exception as e:
+         print(f"[WARN] torch.compile failed: {e}")
+
+ # --------------------------- Freeze/Unfreeze ---------------------------
+ core = get_core_model(vae)
+
+ for p in core.parameters():
+     p.requires_grad = False
+
+ unfrozen_param_names = []
+
+ if full_training and not train_encoder_only:
+     for name, p in core.named_parameters():
+         p.requires_grad = True
+         unfrozen_param_names.append(name)
+     loss_ratios["kl"] = float(kl_ratio)
+     trainable_module = core
+ else:
+     # train only the decoder's 0th block + post_quant_conv
+     if hasattr(core, "encoder"):
+         if train_down_only:  # hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
+             # --- only the 0th up_block ---
+             for name, p in core.encoder.down_blocks[0].named_parameters():
+                 p.requires_grad = True
+                 unfrozen_param_names.append(f"{name}")
+         else:
+             print("Decoder — fallback to full decoder")
+             for name, p in core.decoder.named_parameters():
+                 p.requires_grad = True
+                 unfrozen_param_names.append(f"decoder.{name}")
+             if hasattr(core, "post_quant_conv"):
+                 for name, p in core.post_quant_conv.named_parameters():
+                     p.requires_grad = True
+                     unfrozen_param_names.append(f"post_quant_conv.{name}")
+     trainable_module = core.decoder if hasattr(core, "decoder") else core
+
+
+ print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
+ for nm in unfrozen_param_names[:200]:
+     print(" ", nm)
+
+ # --------------------------- Dataset ---------------------------
+ class PngFolderDataset(Dataset):
+     def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
+         self.root_dir = root_dir
+         self.resolution = resolution
+         self.paths = []
+         for root, _, files in os.walk(root_dir):
+             for fname in files:
+                 if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
+                     self.paths.append(os.path.join(root, fname))
+         if limit:
+             self.paths = self.paths[:limit]
+         valid = []
+         for p in self.paths:
+             try:
+                 with Image.open(p) as im:
+                     im.verify()
+                 valid.append(p)
+             except (OSError, UnidentifiedImageError):
+                 continue
+         self.paths = valid
+         if len(self.paths) == 0:
+             raise RuntimeError(f"No valid PNG images found under {root_dir}")
+         random.shuffle(self.paths)
+
+     def __len__(self):
+         return len(self.paths)
+
+     def __getitem__(self, idx):
+         p = self.paths[idx % len(self.paths)]
+         with Image.open(p) as img:
+             img = img.convert("RGB")
+             if not resize_long_side or resize_long_side <= 0:
+                 return img
+             w, h = img.size
+             long = max(w, h)
+             if long <= resize_long_side:
+                 return img
+             scale = resize_long_side / float(long)
+             new_w = int(round(w * scale))
+             new_h = int(round(h * scale))
+             return img.resize((new_w, new_h), Image.BICUBIC)
+
+ def random_crop(img, sz):
+     w, h = img.size
+     if w < sz or h < sz:
+         img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
+     x = random.randint(0, max(1, img.width - sz))
+     y = random.randint(0, max(1, img.height - sz))
+     return img.crop((x, y, x + sz, y + sz))
+
+ tfm = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+ ])
+
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
+ print("len(dataset)", len(dataset))
+ if len(dataset) < batch_size:
+     raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
+
+ def collate_fn(batch):
+     imgs = []
+     for img in batch:
+         img = random_crop(img, high_resolution)
+         imgs.append(tfm(img))
+     return torch.stack(imgs)
+
+ dataloader = DataLoader(
+     dataset,
+     batch_size=batch_size,
+     shuffle=True,
+     collate_fn=collate_fn,
+     num_workers=num_workers,
+     pin_memory=True,
+     drop_last=True
+ )
+
+ # --------------------------- Optimizer ---------------------------
+ def get_param_groups(module, weight_decay=0.001):
+     no_decay_tokens = ("bias", "norm", "rms", "layernorm")
+     decay_params, no_decay_params = [], []
+     for n, p in module.named_parameters():
+         if not p.requires_grad:
+             continue
+         n_l = n.lower()
+         if any(t in n_l for t in no_decay_tokens):
+             no_decay_params.append(p)
+         else:
+             decay_params.append(p)
+     return [
+         {"params": decay_params, "weight_decay": weight_decay},
+         {"params": no_decay_params, "weight_decay": 0.0},
+     ]
+
+ def create_optimizer(name, param_groups):
+     if name == "adam8bit":
+         return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
+     raise ValueError(name)
+
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
+ optimizer = create_optimizer(optimizer_type, param_groups)
+
+ # --------------------------- LR schedule ---------------------------
+ batches_per_epoch = len(dataloader)
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
+ total_steps = steps_per_epoch * num_epochs
+
+ def lr_lambda(step):
+     if not use_decay:
+         return 1.0
+     x = float(step) / float(max(1, total_steps))
+     warmup = float(warmup_percent)
+     min_ratio = float(min_learning_rate) / float(base_learning_rate)
+     if x < warmup:
+         return min_ratio + (1.0 - min_ratio) * (x / warmup)
+     decay_ratio = (x - warmup) / (1.0 - warmup)
+     return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
+
+ scheduler = LambdaLR(optimizer, lr_lambda)
+
+ # Preparation
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
+ trainable_params = [p for p in vae.parameters() if p.requires_grad]
+
+ # fdl
+ fdl_loss = FDL_loss()
+ fdl_loss = fdl_loss.to(accelerator.device)
+
+ # --------------------------- LPIPS and helpers ---------------------------
+ _lpips_net = None
+ def _get_lpips():
+     global _lpips_net
+     if _lpips_net is None:
+         _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
+     return _lpips_net
+
+ _sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
+ _sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)
+ def sobel_edges(x: torch.Tensor) -> torch.Tensor:
+     C = x.shape[1]
+     kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+     ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+     gx = F.conv2d(x, kx, padding=1, groups=C)
+     gy = F.conv2d(x, ky, padding=1, groups=C)
+     return torch.sqrt(gx * gx + gy * gy + 1e-12)
+
+ class MedianLossNormalizer:
+     def __init__(self, desired_ratios: dict, window_steps: int):
+         s = sum(desired_ratios.values())
+         self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
+         self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
+         self.window = window_steps
+
+     def update_and_total(self, abs_losses: dict):
+         for k, v in abs_losses.items():
+             if k in self.buffers:
+                 self.buffers[k].append(float(v.detach().abs().cpu()))
+         meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
+         coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
+         total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
+         return total, coeffs, meds
+
+ if full_training and not train_encoder_only:
+     loss_ratios["kl"] = float(kl_ratio)
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
+
+ # --------------------------- Samples ---------------------------
+ @torch.no_grad()
+ def get_fixed_samples(n=3):
+     idx = random.sample(range(len(dataset)), min(n, len(dataset)))
+     pil_imgs = [dataset[i] for i in idx]
+     tensors = []
+     for img in pil_imgs:
+         img = random_crop(img, high_resolution)
+         tensors.append(tfm(img))
+     return torch.stack(tensors).to(accelerator.device, dtype)
+
+ fixed_samples = get_fixed_samples()
+
+ @torch.no_grad()
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
+     arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
+     return Image.fromarray(arr)
+
+
+ @torch.no_grad()
+ def generate_and_save_samples(step=None):
+     try:
+         # temp_vae = accelerator.unwrap_model(vae).eval()
+         if hasattr(vae, "module"):
+             # if wrapped in DDP/DistributedDataParallel
+             unwrapped_vae = vae.module
+         else:
+             unwrapped_vae = vae
+
+         # if torch.compile was used, grab the original
+         if hasattr(unwrapped_vae, "_orig_mod"):
+             temp_vae = unwrapped_vae._orig_mod
+         else:
+             temp_vae = unwrapped_vae
+
+         temp_vae = temp_vae.eval()
+         lpips_net = _get_lpips()
+         with torch.no_grad():
+             orig_high = fixed_samples
+             orig_low = F.interpolate(
+                 orig_high,
+                 size=(model_resolution, model_resolution),
+                 mode="bilinear",
+                 align_corners=False
+             )
+             model_dtype = next(temp_vae.parameters()).dtype
+             orig_low = orig_low.to(dtype=model_dtype)
+
+             # encode/decode, accounting for video mode
+             if is_video_vae(temp_vae):
+                 x_in = orig_low.unsqueeze(2)  # [B,3,1,H,W]
+                 enc = temp_vae.encode(x_in)
+                 latents_mean = enc.latent_dist.mean
+                 dec = temp_vae.decode(latents_mean).sample  # [B,3,1,H,W]
+                 rec = dec.squeeze(2)  # [B,3,H,W]
+             else:
+                 enc = temp_vae.encode(orig_low)
+                 latents_mean = enc.latent_dist.mean
+                 rec = temp_vae.decode(latents_mean).sample
+
+             # resize if needed
+             # if rec.shape[-2:] != orig_high.shape[-2:]:
+             #     rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
+
+             # save all real/decoded pairs
+             for i in range(rec.shape[0]):
+                 real_img = _to_pil_uint8(orig_high[i])
+                 dec_img = _to_pil_uint8(rec[i])
+                 real_img.save(f"{generated_folder}/sample_real_{i}.png")
+                 dec_img.save(f"{generated_folder}/sample_decoded_{i}.png")
+
+             # LPIPS
+             lpips_scores = []
+             for i in range(rec.shape[0]):
+                 orig_full = orig_high[i:i+1].to(torch.float32)
+                 rec_full = rec[i:i+1].to(torch.float32)
+                 # if rec_full.shape[-2:] != orig_full.shape[-2:]:
+                 #     rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
+                 lpips_val = lpips_net(orig_full, rec_full).item()
+                 lpips_scores.append(lpips_val)
+             avg_lpips = float(np.mean(lpips_scores))
+
+             # W&B logging
+             if use_wandb and accelerator.is_main_process:
+                 log_data = {"lpips_mean": avg_lpips}
+                 for i in range(rec.shape[0]):
+                     log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.png", caption=f"real_{i}")
+                     log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.png", caption=f"decoded_{i}")
+                 wandb.log(log_data, step=step)
+
+     finally:
+         gc.collect()
+         torch.cuda.empty_cache()
+
+
+ if accelerator.is_main_process and save_model:
+     print("Generating samples before training starts...")
+     generate_and_save_samples(0)
+
+ accelerator.wait_for_everyone()
+
+ # --------------------------- Training ---------------------------
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
+ global_step = 0
+ min_loss = float("inf")
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
+
+ for epoch in range(num_epochs):
+     vae.train()
+     batch_losses, batch_grads = [], []
+     track_losses = {k: [] for k in loss_ratios.keys()}
+
+     for imgs in dataloader:
+         with accelerator.accumulate(vae):
+             imgs = imgs.to(accelerator.device)
+
+             if high_resolution != model_resolution:
+                 imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="area")  # mode="bilinear", align_corners=False)
+             else:
+                 imgs_low = imgs
+
+             model_dtype = next(vae.parameters()).dtype
+             imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
+
+             # instead of: current_vae = accelerator.unwrap_model(vae)
+             unwrapped = vae.module if hasattr(vae, "module") else vae
+             current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
+
+             # QWEN: encode/decode with T=1
+             if is_video_vae(current_vae):
+                 x_in = imgs_low_model.unsqueeze(2)  # [B,3,1,H,W]
+                 enc = current_vae.encode(x_in)
+                 latents = enc.latent_dist.mean if train_encoder_only else enc.latent_dist.sample()
+                 dec = current_vae.decode(latents).sample  # [B,3,1,H,W]
+                 rec = dec.squeeze(2)  # [B,3,H,W]
+             else:
+                 enc = current_vae.encode(imgs_low_model)
+                 latents = enc.latent_dist.mean if train_encoder_only else enc.latent_dist.sample()
+                 rec = current_vae.decode(latents).sample
+
+             # if rec.shape[-2:] != imgs.shape[-2:]:
+             #     rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
+
+             rec_f32 = rec.to(torch.float32)
+             imgs_f32 = imgs.to(torch.float32)
+
+             abs_losses = {
+                 "mae": F.l1_loss(rec_f32, imgs_f32),
+                 "mse": F.mse_loss(rec_f32, imgs_f32),
+                 "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
+                 "fdl": fdl_loss(rec_f32, imgs_f32),
+                 "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
+             }
+
+             if full_training and not train_encoder_only:
+                 mean = enc.latent_dist.mean
+                 logvar = enc.latent_dist.logvar
+                 kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
+                 abs_losses["kl"] = kl
+             else:
+                 abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
+
+             total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
+
+             if torch.isnan(total_loss) or torch.isinf(total_loss):
+                 raise RuntimeError("NaN/Inf loss")
+
+             accelerator.backward(total_loss)
+
+             grad_norm = torch.tensor(0.0, device=accelerator.device)
+             if accelerator.sync_gradients:
+                 grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
+             optimizer.step()
+             scheduler.step()
+             optimizer.zero_grad(set_to_none=True)
+             global_step += 1
+             progress.update(1)
+
+             if accelerator.is_main_process:
+                 try:
+                     current_lr = optimizer.param_groups[0]["lr"]
+                 except Exception:
+                     current_lr = scheduler.get_last_lr()[0]
+
+                 batch_losses.append(total_loss.detach().item())
+                 batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
+                 for k, v in abs_losses.items():
+                     track_losses[k].append(float(v.detach().item()))
+
+                 if use_wandb and accelerator.sync_gradients:
+                     log_dict = {
+                         "total_loss": float(total_loss.detach().item()),
+                         "learning_rate": current_lr,
+                         "epoch": epoch,
+                         "grad_norm": batch_grads[-1],
+                     }
+                     for k, v in abs_losses.items():
+                         log_dict[f"loss_{k}"] = float(v.detach().item())
+                     for k in coeffs:
+                         log_dict[f"coeff_{k}"] = float(coeffs[k])
+                         log_dict[f"median_{k}"] = float(meds[k])
+                     wandb.log(log_dict, step=global_step)
+
+         if global_step > 0 and global_step % sample_interval == 0:
+             if accelerator.is_main_process:
+                 generate_and_save_samples(global_step)
+             accelerator.wait_for_everyone()
+
+             n_micro = sample_interval * gradient_accumulation_steps
+             avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
+             avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
+
+             if accelerator.is_main_process:
+                 print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
+                 if save_model and avg_loss < min_loss * save_barrier:
+                     min_loss = avg_loss
+                     unwrapped = vae.module if hasattr(vae, "module") else vae
+                     current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
+                     current_vae.save_pretrained(save_as)
+                 if use_wandb:
+                     wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
+
+     if accelerator.is_main_process:
+         epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
+         print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
+         if use_wandb:
+             wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
+
+ # --------------------------- Final save ---------------------------
+ if accelerator.is_main_process:
+     print("Training finished – saving final model")
+     if save_model:
+         unwrapped = vae.module if hasattr(vae, "module") else vae
+         current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
+         current_vae.save_pretrained(save_as)
+
+ accelerator.free_memory()
+ if torch.distributed.is_initialized():
+     torch.distributed.destroy_process_group()
+ print("Done!")
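
The MedianLossNormalizer above keeps each term's contribution pinned to its configured share by dividing the target ratio by a running median of that term's raw magnitude. A toy illustration of the arithmetic with made-up loss values (not from a real run):

from collections import deque

import numpy as np

# Two losses with target shares 0.7 and 0.1 (normalized to 0.875 / 0.125).
ratios = {"lpips": 0.70, "mse": 0.10}
s = sum(ratios.values())
ratios = {k: v / s for k, v in ratios.items()}

history = {"lpips": deque([0.30, 0.28, 0.29], maxlen=250),
           "mse": deque([0.02, 0.03, 0.02], maxlen=250)}
meds = {k: float(np.median(list(v))) for k, v in history.items()}
coeffs = {k: ratios[k] / max(meds[k], 1e-12) for k in ratios}

# When a raw loss sits at its median, its weighted term contributes exactly its share:
print({k: coeffs[k] * meds[k] for k in coeffs})  # ≈ {'lpips': 0.875, 'mse': 0.125}

This is what makes the hand-set ratios at the top of the script comparable across losses with very different scales. A typical single-GPU run would then just be accelerate launch train_vae_fdl.py, with the constants in the Parameters block edited in place.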
transfer_simplevae3.ipynb ADDED
@@ -0,0 +1,415 @@
[file body identical to .ipynb_checkpoints/transfer_simplevae3-checkpoint.ipynb, shown above]
181
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
182
+ " (nonlinearity): SiLU()\n",
183
+ " )\n",
184
+ " )\n",
185
+ " (upsamplers): ModuleList(\n",
186
+ " (0): Upsample2D(\n",
187
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
188
+ " )\n",
189
+ " )\n",
190
+ " )\n",
191
+ " (3): UpDecoderBlock2D(\n",
192
+ " (resnets): ModuleList(\n",
193
+ " (0): ResnetBlock2D(\n",
194
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
195
+ " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
196
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
197
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
198
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
199
+ " (nonlinearity): SiLU()\n",
200
+ " (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
201
+ " )\n",
202
+ " (1-3): 3 x ResnetBlock2D(\n",
203
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
204
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
205
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
206
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
207
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
208
+ " (nonlinearity): SiLU()\n",
209
+ " )\n",
210
+ " )\n",
211
+ " (upsamplers): ModuleList(\n",
212
+ " (0): Upsample2D(\n",
213
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
214
+ " )\n",
215
+ " )\n",
216
+ " )\n",
217
+ " (4): UpDecoderBlock2D(\n",
218
+ " (resnets): ModuleList(\n",
219
+ " (0-3): 4 x ResnetBlock2D(\n",
220
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
221
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
222
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
223
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
224
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
225
+ " (nonlinearity): SiLU()\n",
226
+ " )\n",
227
+ " )\n",
228
+ " )\n",
229
+ " )\n",
230
+ " (mid_block): UNetMidBlock2D(\n",
231
+ " (attentions): ModuleList(\n",
232
+ " (0): Attention(\n",
233
+ " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
234
+ " (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
235
+ " (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
236
+ " (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
237
+ " (to_out): ModuleList(\n",
238
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
239
+ " (1): Dropout(p=0.0, inplace=False)\n",
240
+ " )\n",
241
+ " )\n",
242
+ " )\n",
243
+ " (resnets): ModuleList(\n",
244
+ " (0-1): 2 x ResnetBlock2D(\n",
245
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
246
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
247
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
248
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
249
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
250
+ " (nonlinearity): SiLU()\n",
251
+ " )\n",
252
+ " )\n",
253
+ " )\n",
254
+ " (condition_encoder): MaskConditionEncoder(\n",
255
+ " (layers): Sequential(\n",
256
+ " (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
257
+ " (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
258
+ " (2): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
259
+ " (3): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
260
+ " (4): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
261
+ " )\n",
262
+ " )\n",
263
+ " (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
264
+ " (conv_act): SiLU()\n",
265
+ " (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
266
+ " )\n",
267
+ " (quant_conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
268
+ " (post_quant_conv): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))\n",
269
+ ")\n"
270
+ ]
271
+ },
272
+ {
273
+ "name": "stderr",
274
+ "output_type": "stream",
275
+ "text": [
276
+ "\n"
277
+ ]
278
+ }
279
+ ],
280
+ "source": [
281
+ "from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL\n",
282
+ "import torch\n",
283
+ "from tqdm import tqdm\n",
284
+ "\n",
285
+ "# ---- Конфиг новой модели ----\n",
286
+ "config = {\n",
287
+ " \"_class_name\": \"AsymmetricAutoencoderKL\",\n",
288
+ " \"act_fn\": \"silu\",\n",
289
+ " \"in_channels\": 3,\n",
290
+ " \"out_channels\": 3,\n",
291
+ " \"scaling_factor\": 1.0,\n",
292
+ " \"norm_num_groups\": 32,\n",
293
+ " \"down_block_out_channels\": [128, 128, 256, 512, 512],\n",
294
+ " \"down_block_types\": [\n",
295
+ " \"DownEncoderBlock2D\",\n",
296
+ " \"DownEncoderBlock2D\",\n",
297
+ " \"DownEncoderBlock2D\",\n",
298
+ " \"DownEncoderBlock2D\",\n",
299
+ " ],\n",
300
+ " \"latent_channels\": 16,\n",
301
+ " \"up_block_out_channels\": [128, 128, 256, 512, 512],\n",
302
+ " \"up_block_types\": [\n",
303
+ " \"UpDecoderBlock2D\",\n",
304
+ " \"UpDecoderBlock2D\",\n",
305
+ " \"UpDecoderBlock2D\",\n",
306
+ " \"UpDecoderBlock2D\",\n",
307
+ " \"UpDecoderBlock2D\",\n",
308
+ " ],\n",
309
+ "}\n",
310
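+ "# scaling_factor is deliberately left at 1.0, i.e. latents are stored unscaled;\n",
+ "# vae_comp.ipynb later in this commit measures their actual spread (log-variance)\n",
+ "# to judge what rescaling, if any, would be appropriate.\n",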
+ "\n",
311
+ "# ---- Создание пустой асимметричной модели ----\n",
312
+ "vae = AsymmetricAutoencoderKL(\n",
313
+ " act_fn=config[\"act_fn\"],\n",
314
+ " down_block_out_channels=config[\"down_block_out_channels\"],\n",
315
+ " down_block_types=config[\"down_block_types\"],\n",
316
+ " latent_channels=config[\"latent_channels\"],\n",
317
+ " up_block_out_channels=config[\"up_block_out_channels\"],\n",
318
+ " up_block_types=config[\"up_block_types\"],\n",
319
+ " in_channels=config[\"in_channels\"],\n",
320
+ " out_channels=config[\"out_channels\"],\n",
321
+ " scaling_factor=config[\"scaling_factor\"],\n",
322
+ " norm_num_groups=config[\"norm_num_groups\"],\n",
323
+ " layers_per_down_block=2,\n",
324
+ " layers_per_up_block=3,\n",
325
+ " sample_size=1024\n",
326
+ ")\n",
327
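+ "# Note: down_block_out_channels has 5 entries but down_block_types only 4, so diffusers\n",
+ "# builds 4 encoder down blocks (128/128/256/512), each followed by a downsampler, and\n",
+ "# uses the last entry (512) for the mid block -- matching the architecture printed above.\n",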
+ "\n",
328
+ "vae.save_pretrained(\"asymmetric_vae_empty\")\n",
329
+ "print(\"✅ Создана новая модель:\", type(vae))\n",
330
+ "\n",
331
+ "# ---- Функция переноса весов старого VAE ----\n",
332
+ "def transfer_weights(old_path, new_path, save_path=\"asymmetric_vae\", device=\"cuda\", dtype=torch.float16):\n",
333
+ " old_vae = AsymmetricAutoencoderKL.from_pretrained(old_path).to(device, dtype=dtype)\n",
334
+ " new_vae = AsymmetricAutoencoderKL.from_pretrained(new_path).to(device, dtype=dtype)\n",
335
+ "\n",
336
+ " old_sd = old_vae.state_dict()\n",
337
+ " new_sd = new_vae.state_dict()\n",
338
+ "\n",
339
+ " transferred_keys = set()\n",
340
+ " transfer_stats = {\"перенесено\": 0, \"дублировано\": 0, \"пропущено\": 0}\n",
341
+ "\n",
342
+ " print(\"\\n--- Перенос весов ---\")\n",
343
+ " for k, v in tqdm(old_sd.items()):\n",
344
+ " # Копирование энкодера и прочих совпадающих ключей\n",
345
+ " if (\"encoder\" in k) or (\"quant_conv\" in k) or (\"post_quant_conv\" in k):\n",
346
+ " if k in new_sd and new_sd[k].shape == v.shape:\n",
347
+ " new_sd[k] = v.clone()\n",
348
+ " transferred_keys.add(k)\n",
349
+ " transfer_stats[\"перенесено\"] += 1\n",
350
+ " continue\n",
351
+ "\n",
352
+ " # Копирование декодера (без сдвига)\n",
353
+ " if \"decoder.up_blocks\" in k:\n",
354
+ " if k in new_sd and new_sd[k].shape == v.shape:\n",
355
+ " new_sd[k] = v.clone()\n",
356
+ " transferred_keys.add(k)\n",
357
+ " transfer_stats[\"перенесено\"] += 1\n",
358
+ " continue\n",
359
+ "\n",
360
+ " # Дублирование весов старого первого 512→512 блока в новый блок 64→128 для апскейла\n",
361
+ " ref_prefix = \"encoder.down_blocks.1\"\n",
362
+ " new_prefix = \"encoder.down_blocks.0\"\n",
363
+ " for k, v in old_sd.items():\n",
364
+ " if k.startswith(ref_prefix) and new_prefix + k[len(ref_prefix):] in new_sd:\n",
365
+ " new_k = k.replace(ref_prefix, new_prefix)\n",
366
+ " if new_sd[new_k].shape == v.shape:\n",
367
+ " new_sd[new_k] = v.clone()\n",
368
+ " transferred_keys.add(new_k)\n",
369
+ " transfer_stats[\"дублировано\"] += 1\n",
370
+ "\n",
371
+ " # Загрузка и сохранение\n",
372
+ " new_vae.load_state_dict(new_sd, strict=False)\n",
373
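+ "    # strict=False silently drops any leftover mismatches; the return value lists them\n",
+ "    # for inspection, e.g.: missing, unexpected = new_vae.load_state_dict(new_sd, strict=False)\n",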
+ " new_vae.save_pretrained(save_path)\n",
374
+ "\n",
375
+ " print(\"\\n✅ Перенос завершён.\")\n",
376
+ " print(\"Статистика:\")\n",
377
+ " for k, v in transfer_stats.items():\n",
378
+ " print(f\" {k}: {v}\")\n",
379
+ " print(new_vae)\n",
380
+ "\n",
381
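+ "# Optional sanity check (a sketch, kept commented out; assumes the transfer below has run):\n",
+ "# vae11 = AsymmetricAutoencoderKL.from_pretrained(\"vae11\", torch_dtype=torch.float16).to(\"cuda\")\n",
+ "# x = torch.randn(1, 3, 256, 256, device=\"cuda\", dtype=torch.float16)\n",
+ "# with torch.no_grad():\n",
+ "#     z = vae11.encode(x).latent_dist.sample()  # expect [1, 16, 16, 16] (f16, 16 latent channels)\n",
+ "#     print(vae11.decode(z).sample.shape)       # expect torch.Size([1, 3, 256, 256])\n",
+ "\n",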
+ "# ---- Запуск переноса ----\n",
382
+ "transfer_weights(\"vae10\", \"asymmetric_vae_empty\", save_path=\"vae11\")\n"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "id": "59fcafb9-6d89-49b4-8362-b4891f591687",
389
+ "metadata": {},
390
+ "outputs": [],
391
+ "source": []
392
+ }
393
+ ],
394
+ "metadata": {
395
+ "kernelspec": {
396
+ "display_name": "Python 3 (ipykernel)",
397
+ "language": "python",
398
+ "name": "python3"
399
+ },
400
+ "language_info": {
401
+ "codemirror_mode": {
402
+ "name": "ipython",
403
+ "version": 3
404
+ },
405
+ "file_extension": ".py",
406
+ "mimetype": "text/x-python",
407
+ "name": "python",
408
+ "nbconvert_exporter": "python",
409
+ "pygments_lexer": "ipython3",
410
+ "version": "3.11.10"
411
+ }
412
+ },
413
+ "nbformat": 4,
414
+ "nbformat_minor": 5
415
+ }
vae_comp.ipynb ADDED
@@ -0,0 +1,214 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "b3b23a40-8354-4287-bac2-32f9d084fff3",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py:202: UserWarning: The `local_dir_use_symlinks` argument is deprecated and ignored in `hf_hub_download`. Downloading to a local directory does not use symlinks anymore.\n",
14
+ " warnings.warn(\n"
15
+ ]
16
+ },
17
+ {
18
+ "data": {
19
+ "application/vnd.jupyter.widget-view+json": {
20
+ "model_id": "96d38ff0fa134b02a5a21c96bdfd36b5",
21
+ "version_major": 2,
22
+ "version_minor": 0
23
+ },
24
+ "text/plain": [
25
+ "vae/config.json: 0%| | 0.00/752 [00:00<?, ?B/s]"
26
+ ]
27
+ },
28
+ "metadata": {},
29
+ "output_type": "display_data"
30
+ },
31
+ {
32
+ "data": {
33
+ "application/vnd.jupyter.widget-view+json": {
34
+ "model_id": "0a44c60705d44f58b5a07ead45936327",
35
+ "version_major": 2,
36
+ "version_minor": 0
37
+ },
38
+ "text/plain": [
39
+ "vae/diffusion_pytorch_model.safetensors: 0%| | 0.00/191M [00:00<?, ?B/s]"
40
+ ]
41
+ },
42
+ "metadata": {},
43
+ "output_type": "display_data"
44
+ },
45
+ {
46
+ "name": "stdout",
47
+ "output_type": "stream",
48
+ "text": [
49
+ "sdxs_vae log-variance: 1.840\n"
50
+ ]
51
+ },
52
+ {
53
+ "name": "stderr",
54
+ "output_type": "stream",
55
+ "text": [
56
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
57
+ ]
58
+ },
59
+ {
60
+ "name": "stdout",
61
+ "output_type": "stream",
62
+ "text": [
63
+ "vae9 log-variance: 1.840\n",
64
+ "Готово\n"
65
+ ]
66
+ }
67
+ ],
68
+ "source": [
69
+ "import torch\n",
70
+ "from PIL import Image\n",
71
+ "from diffusers import AutoencoderKL,AsymmetricAutoencoderKL\n",
72
+ "from torchvision.transforms.functional import to_pil_image\n",
73
+ "import matplotlib.pyplot as plt\n",
74
+ "import os\n",
75
+ "from torchvision.transforms import ToTensor, Normalize, CenterCrop\n",
76
+ "\n",
77
+ "# путь к вашей картинке\n",
78
+ "IMG_PATH = \"1234567890.png\"\n",
79
+ "OUT_DIR = \"vaetest\"\n",
80
+ "device = \"cuda\"\n",
81
+ "dtype = torch.float32 # ← единый float32\n",
82
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
83
+ "\n",
84
+ "# список VAE\n",
85
+ "VAES = {\n",
86
+ " #\"sdxl\": \"madebyollin/sdxl-vae-fp16-fix\",\n",
87
+ " \"sdxs_vae\": \"AiArtLab/sdxs-1b\",\n",
88
+ " #\"vae8\": \"/workspace/simplevae2x/vae8\",\n",
89
+ " \"vae9\": \"/workspace/simplevae2x/vae9\"\n",
90
+ "}\n",
91
+ "\n",
92
+ "def load_image(path):\n",
93
+ " img = Image.open(path).convert('RGB')\n",
94
+ " # обрезаем до кратности 8\n",
95
+ " w, h = img.size\n",
96
+ " img = CenterCrop((h // 8 * 8, w // 8 * 8))(img)\n",
97
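+ "    # (note: these f16-style encoders may need a multiple of 16 to keep the input\n",
+ "    # and reconstruction sizes exactly aligned)\n",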
+ " tensor = ToTensor()(img).unsqueeze(0) # [0,1]\n",
98
+ " tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(tensor) # [-1,1]\n",
99
+ " return img, tensor.to(device, dtype=dtype)\n",
100
+ "\n",
101
+ "# обратно в PIL\n",
102
+ "def tensor_to_img(t):\n",
103
+ " t = (t * 0.5 + 0.5).clamp(0, 1)\n",
104
+ " return to_pil_image(t[0])\n",
105
+ "\n",
106
+ "def logvariance(latents):\n",
107
+ " \"\"\"Возвращает лог-дисперсию по всем элементам.\"\"\"\n",
108
+ " return torch.log(latents.var() + 1e-8).item()\n",
109
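+ "\n",
+ "# Interpretation: for roughly standard-normal latents the log-variance is ~0; a value of\n",
+ "# ~1.84 implies var ~ e^1.84 ~ 6.3, so multiplying latents by ~1/sqrt(6.3) ~ 0.40 would\n",
+ "# approximately standardize them (a back-of-envelope estimate, not a tuned scaling_factor).\n",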
+ "\n",
110
+ "def plot_latent_distribution(latents, title, save_path):\n",
111
+ " \"\"\"Гистограмма + QQ-plot.\"\"\"\n",
112
+ " lat = latents.detach().cpu().numpy().flatten()\n",
113
+ " plt.figure(figsize=(10, 4))\n",
114
+ "\n",
115
+ " # гистограмма\n",
116
+ " plt.subplot(1, 2, 1)\n",
117
+ " plt.hist(lat, bins=100, density=True, alpha=0.7, color='steelblue')\n",
118
+ " plt.title(f\"{title} histogram\")\n",
119
+ " plt.xlabel(\"latent value\")\n",
120
+ " plt.ylabel(\"density\")\n",
121
+ "\n",
122
+ " # QQ-plot\n",
123
+ " from scipy.stats import probplot\n",
124
+ " plt.subplot(1, 2, 2)\n",
125
+ " probplot(lat, dist=\"norm\", plot=plt)\n",
126
+ " plt.title(f\"{title} QQ-plot\")\n",
127
+ "\n",
128
+ " plt.tight_layout()\n",
129
+ " plt.savefig(save_path)\n",
130
+ " plt.close()\n",
131
+ "\n",
132
+ "for name, repo in VAES.items():\n",
133
+ " if name==\"sdxs_vae\":\n",
134
+ " vae = AsymmetricAutoencoderKL.from_pretrained(repo, subfolder=\"vae\", torch_dtype=dtype).to(device)\n",
135
+ " else:\n",
136
+ " vae = AsymmetricAutoencoderKL.from_pretrained(repo, torch_dtype=dtype).to(device)#, subfolder=\"vae\", variant=\"fp16\"\n",
137
+ "\n",
138
+ " cfg = vae.config\n",
139
+ " scale = getattr(cfg, \"scaling_factor\", 1.)\n",
140
+ " shift = getattr(cfg, \"shift_factor\", 0.0)\n",
141
+ " mean = getattr(cfg, \"latents_mean\", None)\n",
142
+ " std = getattr(cfg, \"latents_std\", None)\n",
143
+ "\n",
144
+ " C = 4 # 4 для SDXL\n",
145
+ " if mean is not None:\n",
146
+ " mean = torch.tensor(mean, device=device, dtype=dtype).view(1, C, 1, 1)\n",
147
+ " if std is not None:\n",
148
+ " std = torch.tensor(std, device=device, dtype=dtype).view(1, C, 1, 1)\n",
149
+ " if shift is not None:\n",
150
+ " shift = torch.tensor(shift, device=device, dtype=dtype)\n",
151
+ " else:\n",
152
+ " shift = 0.0 \n",
153
+ "\n",
154
+ " scale = torch.tensor(scale, device=device, dtype=dtype)\n",
155
+ "\n",
156
+ " img, x = load_image(IMG_PATH)\n",
157
+ " img.save(os.path.join(OUT_DIR, f\"original.png\"))\n",
158
+ "\n",
159
+ " with torch.no_grad():\n",
160
+ " # encode\n",
161
+ " latents = vae.encode(x).latent_dist.sample().to(dtype)\n",
162
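+ "        # .sample() draws a stochastic latent; latent_dist.mode() would make the\n",
+ "        # per-VAE comparison deterministic\n",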
+ " if mean is not None and std is not None:\n",
163
+ " latents = (latents - mean) / std\n",
164
+ " latents = latents * scale + shift\n",
165
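+ "        # scale/shift (and mean/std) are exactly undone before decode below, so the\n",
+ "        # reconstruction tests the VAE itself; only the stats/plot use normalized latents\n",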
+ "\n",
166
+ " lv = logvariance(latents)\n",
167
+ " print(f\"{name} log-variance: {lv:.3f}\")\n",
168
+ "\n",
169
+ " # график\n",
170
+ " plot_latent_distribution(latents, f\"{name}_latents\",\n",
171
+ " os.path.join(OUT_DIR, f\"dist_{name}.png\"))\n",
172
+ "\n",
173
+ " # decode\n",
174
+ " latents = (latents - shift) / scale\n",
175
+ " if mean is not None and std is not None:\n",
176
+ " latents = latents * std + mean\n",
177
+ " rec = vae.decode(latents).sample\n",
178
+ "\n",
179
+ " tensor_to_img(rec).save(os.path.join(OUT_DIR, f\"decoded_{name}.png\"))\n",
180
+ "\n",
181
+ "print(\"Готово\")"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": null,
187
+ "id": "200b72ab-1978-4d71-9aba-b1ef97cf0b27",
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": []
191
+ }
192
+ ],
193
+ "metadata": {
194
+ "kernelspec": {
195
+ "display_name": "Python 3 (ipykernel)",
196
+ "language": "python",
197
+ "name": "python3"
198
+ },
199
+ "language_info": {
200
+ "codemirror_mode": {
201
+ "name": "ipython",
202
+ "version": 3
203
+ },
204
+ "file_extension": ".py",
205
+ "mimetype": "text/x-python",
206
+ "name": "python",
207
+ "nbconvert_exporter": "python",
208
+ "pygments_lexer": "ipython3",
209
+ "version": "3.11.10"
210
+ }
211
+ },
212
+ "nbformat": 4,
213
+ "nbformat_minor": 5
214
+ }