babkasotona commited on
Commit
c7fbaf5
·
verified ·
1 Parent(s): 63f4f49

Upload folder using huggingface_hub

Browse files
.ipynb_checkpoints/config-checkpoint.json ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.37.0",
4
+ "_name_or_path": "vae16x32ch",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 128,
9
+ 256,
10
+ 512,
11
+ 512
12
+ ],
13
+ "down_block_types": [
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D",
17
+ "DownEncoderBlock2D",
18
+ "DownEncoderBlock2D"
19
+ ],
20
+ "force_upcast": false,
21
+ "in_channels": 3,
22
+ "latent_channels": 32,
23
+ "latents_mean": [
24
+ -0.03542253375053406,
25
+ 0.20086465775966644,
26
+ -0.016413161531090736,
27
+ -0.0956302210688591,
28
+ -0.2672063112258911,
29
+ 0.2609933018684387,
30
+ -0.07806991040706635,
31
+ -0.48407721519470215,
32
+ 0.21844269335269928,
33
+ -0.1122383326292038,
34
+ 0.27197545766830444,
35
+ -0.18958772718906403,
36
+ 0.18776826560497284,
37
+ 0.0987580344080925,
38
+ 0.2837068736553192,
39
+ -0.4486690163612366,
40
+ 0.4816776514053345,
41
+ 0.02947971224784851,
42
+ -0.1337375044822693,
43
+ -0.39750921726226807,
44
+ -0.08513020724058151,
45
+ -0.054023586213588715,
46
+ -0.3943594992160797,
47
+ 0.23918119072914124,
48
+ -0.12466679513454437,
49
+ 0.09935147315263748,
50
+ 0.31858691573143005,
51
+ 0.48585832118988037,
52
+ -0.6416525840759277,
53
+ -0.15164820849895477,
54
+ -0.4693508744239807,
55
+ -0.13071806728839874
56
+ ],
57
+ "latents_std": [
58
+ 1.5792087316513062,
59
+ 1.5769503116607666,
60
+ 1.5864241123199463,
61
+ 1.6454921960830688,
62
+ 1.5336694717407227,
63
+ 1.5587652921676636,
64
+ 1.5838669538497925,
65
+ 1.5659377574920654,
66
+ 1.6860467195510864,
67
+ 1.5192310810089111,
68
+ 1.573639988899231,
69
+ 1.5953549146652222,
70
+ 1.5271092653274536,
71
+ 1.6246271133422852,
72
+ 1.7054023742675781,
73
+ 1.607722282409668,
74
+ 1.558642864227295,
75
+ 1.5824549198150635,
76
+ 1.6202995777130127,
77
+ 1.6206320524215698,
78
+ 1.6379750967025757,
79
+ 1.6527063846588135,
80
+ 1.498811960220337,
81
+ 1.5706247091293335,
82
+ 1.5854856967926025,
83
+ 1.4828169345855713,
84
+ 1.5693111419677734,
85
+ 1.692481517791748,
86
+ 1.6409776210784912,
87
+ 1.6216280460357666,
88
+ 1.6087706089019775,
89
+ 1.5776633024215698
90
+ ],
91
+ "layers_per_block": 2,
92
+ "mid_block_add_attention": true,
93
+ "norm_num_groups": 32,
94
+ "out_channels": 3,
95
+ "sample_size": 32,
96
+ "scaling_factor": 1.0,
97
+ "shift_factor": null,
98
+ "up_block_types": [
99
+ "UpDecoderBlock2D",
100
+ "UpDecoderBlock2D",
101
+ "UpDecoderBlock2D",
102
+ "UpDecoderBlock2D",
103
+ "UpDecoderBlock2D"
104
+ ],
105
+ "use_post_quant_conv": true,
106
+ "use_quant_conv": true
107
+ }
.ipynb_checkpoints/create_symmetric-checkpoint.ipynb ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 10,
6
+ "id": "407171be-ab46-442b-a0bd-83ca75173eba",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "AutoencoderKL(\n",
14
+ " (encoder): Encoder(\n",
15
+ " (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
16
+ " (down_blocks): ModuleList(\n",
17
+ " (0-1): 2 x DownEncoderBlock2D(\n",
18
+ " (resnets): ModuleList(\n",
19
+ " (0-1): 2 x ResnetBlock2D(\n",
20
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
21
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
22
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
23
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
24
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
25
+ " (nonlinearity): SiLU()\n",
26
+ " )\n",
27
+ " )\n",
28
+ " (downsamplers): ModuleList(\n",
29
+ " (0): Downsample2D(\n",
30
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
31
+ " )\n",
32
+ " )\n",
33
+ " )\n",
34
+ " (2): DownEncoderBlock2D(\n",
35
+ " (resnets): ModuleList(\n",
36
+ " (0): ResnetBlock2D(\n",
37
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
38
+ " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
39
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
40
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
41
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
42
+ " (nonlinearity): SiLU()\n",
43
+ " (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
44
+ " )\n",
45
+ " (1): ResnetBlock2D(\n",
46
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
47
+ " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
48
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
49
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
50
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
51
+ " (nonlinearity): SiLU()\n",
52
+ " )\n",
53
+ " )\n",
54
+ " (downsamplers): ModuleList(\n",
55
+ " (0): Downsample2D(\n",
56
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
57
+ " )\n",
58
+ " )\n",
59
+ " )\n",
60
+ " (3): DownEncoderBlock2D(\n",
61
+ " (resnets): ModuleList(\n",
62
+ " (0): ResnetBlock2D(\n",
63
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
64
+ " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
65
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
66
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
67
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
68
+ " (nonlinearity): SiLU()\n",
69
+ " (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
70
+ " )\n",
71
+ " (1): ResnetBlock2D(\n",
72
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
73
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
74
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
75
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
76
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
77
+ " (nonlinearity): SiLU()\n",
78
+ " )\n",
79
+ " )\n",
80
+ " (downsamplers): ModuleList(\n",
81
+ " (0): Downsample2D(\n",
82
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n",
83
+ " )\n",
84
+ " )\n",
85
+ " )\n",
86
+ " (4): DownEncoderBlock2D(\n",
87
+ " (resnets): ModuleList(\n",
88
+ " (0-1): 2 x ResnetBlock2D(\n",
89
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
90
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
91
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
92
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
93
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
94
+ " (nonlinearity): SiLU()\n",
95
+ " )\n",
96
+ " )\n",
97
+ " )\n",
98
+ " )\n",
99
+ " (mid_block): UNetMidBlock2D(\n",
100
+ " (attentions): ModuleList(\n",
101
+ " (0): Attention(\n",
102
+ " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
103
+ " (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
104
+ " (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
105
+ " (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
106
+ " (to_out): ModuleList(\n",
107
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
108
+ " (1): Dropout(p=0.0, inplace=False)\n",
109
+ " )\n",
110
+ " )\n",
111
+ " )\n",
112
+ " (resnets): ModuleList(\n",
113
+ " (0-1): 2 x ResnetBlock2D(\n",
114
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
115
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
116
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
117
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
118
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
119
+ " (nonlinearity): SiLU()\n",
120
+ " )\n",
121
+ " )\n",
122
+ " )\n",
123
+ " (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
124
+ " (conv_act): SiLU()\n",
125
+ " (conv_out): Conv2d(512, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
126
+ " )\n",
127
+ " (decoder): Decoder(\n",
128
+ " (conv_in): Conv2d(32, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
129
+ " (up_blocks): ModuleList(\n",
130
+ " (0-1): 2 x UpDecoderBlock2D(\n",
131
+ " (resnets): ModuleList(\n",
132
+ " (0-2): 3 x ResnetBlock2D(\n",
133
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
134
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
135
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
136
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
137
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
138
+ " (nonlinearity): SiLU()\n",
139
+ " )\n",
140
+ " )\n",
141
+ " (upsamplers): ModuleList(\n",
142
+ " (0): Upsample2D(\n",
143
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
144
+ " )\n",
145
+ " )\n",
146
+ " )\n",
147
+ " (2): UpDecoderBlock2D(\n",
148
+ " (resnets): ModuleList(\n",
149
+ " (0): ResnetBlock2D(\n",
150
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
151
+ " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
152
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
153
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
154
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
155
+ " (nonlinearity): SiLU()\n",
156
+ " (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
157
+ " )\n",
158
+ " (1-2): 2 x ResnetBlock2D(\n",
159
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
160
+ " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
161
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
162
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
163
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
164
+ " (nonlinearity): SiLU()\n",
165
+ " )\n",
166
+ " )\n",
167
+ " (upsamplers): ModuleList(\n",
168
+ " (0): Upsample2D(\n",
169
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
170
+ " )\n",
171
+ " )\n",
172
+ " )\n",
173
+ " (3): UpDecoderBlock2D(\n",
174
+ " (resnets): ModuleList(\n",
175
+ " (0): ResnetBlock2D(\n",
176
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
177
+ " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
178
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
179
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
180
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
181
+ " (nonlinearity): SiLU()\n",
182
+ " (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
183
+ " )\n",
184
+ " (1-2): 2 x ResnetBlock2D(\n",
185
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
186
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
187
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
188
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
189
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
190
+ " (nonlinearity): SiLU()\n",
191
+ " )\n",
192
+ " )\n",
193
+ " (upsamplers): ModuleList(\n",
194
+ " (0): Upsample2D(\n",
195
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
196
+ " )\n",
197
+ " )\n",
198
+ " )\n",
199
+ " (4): UpDecoderBlock2D(\n",
200
+ " (resnets): ModuleList(\n",
201
+ " (0-2): 3 x ResnetBlock2D(\n",
202
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
203
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
204
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
205
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
206
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
207
+ " (nonlinearity): SiLU()\n",
208
+ " )\n",
209
+ " )\n",
210
+ " )\n",
211
+ " )\n",
212
+ " (mid_block): UNetMidBlock2D(\n",
213
+ " (attentions): ModuleList(\n",
214
+ " (0): Attention(\n",
215
+ " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
216
+ " (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
217
+ " (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
218
+ " (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
219
+ " (to_out): ModuleList(\n",
220
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
221
+ " (1): Dropout(p=0.0, inplace=False)\n",
222
+ " )\n",
223
+ " )\n",
224
+ " )\n",
225
+ " (resnets): ModuleList(\n",
226
+ " (0-1): 2 x ResnetBlock2D(\n",
227
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
228
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
229
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
230
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
231
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
232
+ " (nonlinearity): SiLU()\n",
233
+ " )\n",
234
+ " )\n",
235
+ " )\n",
236
+ " (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
237
+ " (conv_act): SiLU()\n",
238
+ " (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
239
+ " )\n",
240
+ " (quant_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
241
+ " (post_quant_conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
242
+ ")\n"
243
+ ]
244
+ }
245
+ ],
246
+ "source": [
247
+ "from diffusers.models import AutoencoderKL\n",
248
+ "import torch\n",
249
+ "\n",
250
+ "config = {\n",
251
+ " \"_class_name\": \"AutoencoderKL\",\n",
252
+ " \"_diffusers_version\": \"0.36.0\",\n",
253
+ " \"act_fn\": \"silu\",\n",
254
+ " \"block_out_channels\": [\n",
255
+ " 128,\n",
256
+ " 128,\n",
257
+ " 256,\n",
258
+ " 512,\n",
259
+ " 512\n",
260
+ " ],\n",
261
+ " \"down_block_types\": [\n",
262
+ " \"DownEncoderBlock2D\",\n",
263
+ " \"DownEncoderBlock2D\",\n",
264
+ " \"DownEncoderBlock2D\",\n",
265
+ " \"DownEncoderBlock2D\",\n",
266
+ " \"DownEncoderBlock2D\"\n",
267
+ " ],\n",
268
+ " \"force_upcast\": False,\n",
269
+ " \"in_channels\": 3,\n",
270
+ " \"latent_channels\": 32,\n",
271
+ " \"latents_mean\": [\n",
272
+ " -0.03542253375053406,\n",
273
+ " 0.20086465775966644,\n",
274
+ " -0.016413161531090736,\n",
275
+ " -0.0956302210688591,\n",
276
+ " -0.2672063112258911,\n",
277
+ " 0.2609933018684387,\n",
278
+ " -0.07806991040706635,\n",
279
+ " -0.48407721519470215,\n",
280
+ " 0.21844269335269928,\n",
281
+ " -0.1122383326292038,\n",
282
+ " 0.27197545766830444,\n",
283
+ " -0.18958772718906403,\n",
284
+ " 0.18776826560497284,\n",
285
+ " 0.0987580344080925,\n",
286
+ " 0.2837068736553192,\n",
287
+ " -0.4486690163612366,\n",
288
+ " 0.4816776514053345,\n",
289
+ " 0.02947971224784851,\n",
290
+ " -0.1337375044822693,\n",
291
+ " -0.39750921726226807,\n",
292
+ " -0.08513020724058151,\n",
293
+ " -0.054023586213588715,\n",
294
+ " -0.3943594992160797,\n",
295
+ " 0.23918119072914124,\n",
296
+ " -0.12466679513454437,\n",
297
+ " 0.09935147315263748,\n",
298
+ " 0.31858691573143005,\n",
299
+ " 0.48585832118988037,\n",
300
+ " -0.6416525840759277,\n",
301
+ " -0.15164820849895477,\n",
302
+ " -0.4693508744239807,\n",
303
+ " -0.13071806728839874\n",
304
+ " ],\n",
305
+ " \"latents_std\": [\n",
306
+ " 1.5792087316513062,\n",
307
+ " 1.5769503116607666,\n",
308
+ " 1.5864241123199463,\n",
309
+ " 1.6454921960830688,\n",
310
+ " 1.5336694717407227,\n",
311
+ " 1.5587652921676636,\n",
312
+ " 1.5838669538497925,\n",
313
+ " 1.5659377574920654,\n",
314
+ " 1.6860467195510864,\n",
315
+ " 1.5192310810089111,\n",
316
+ " 1.573639988899231,\n",
317
+ " 1.5953549146652222,\n",
318
+ " 1.5271092653274536,\n",
319
+ " 1.6246271133422852,\n",
320
+ " 1.7054023742675781,\n",
321
+ " 1.607722282409668,\n",
322
+ " 1.558642864227295,\n",
323
+ " 1.5824549198150635,\n",
324
+ " 1.6202995777130127,\n",
325
+ " 1.6206320524215698,\n",
326
+ " 1.6379750967025757,\n",
327
+ " 1.6527063846588135,\n",
328
+ " 1.498811960220337,\n",
329
+ " 1.5706247091293335,\n",
330
+ " 1.5854856967926025,\n",
331
+ " 1.4828169345855713,\n",
332
+ " 1.5693111419677734,\n",
333
+ " 1.692481517791748,\n",
334
+ " 1.6409776210784912,\n",
335
+ " 1.6216280460357666,\n",
336
+ " 1.6087706089019775,\n",
337
+ " 1.5776633024215698\n",
338
+ " ],\n",
339
+ " \"layers_per_block\": 2,\n",
340
+ " \"mid_block_add_attention\": True,\n",
341
+ " \"norm_num_groups\": 32,\n",
342
+ " \"out_channels\": 3,\n",
343
+ " \"sample_size\": 32,\n",
344
+ " \"scaling_factor\": 1.0,\n",
345
+ " \"shift_factor\": 0.0,\n",
346
+ " \"up_block_types\": [\n",
347
+ " \"UpDecoderBlock2D\",\n",
348
+ " \"UpDecoderBlock2D\",\n",
349
+ " \"UpDecoderBlock2D\",\n",
350
+ " \"UpDecoderBlock2D\",\n",
351
+ " \"UpDecoderBlock2D\"\n",
352
+ " ],\n",
353
+ " \"use_post_quant_conv\": True,\n",
354
+ " \"use_quant_conv\": True\n",
355
+ "}\n",
356
+ "\n",
357
+ "\n",
358
+ "vae = AutoencoderKL(\n",
359
+ " act_fn=config[\"act_fn\"],\n",
360
+ " block_out_channels=config[\"block_out_channels\"],\n",
361
+ " down_block_types=config[\"down_block_types\"],\n",
362
+ " up_block_types=config[\"up_block_types\"],\n",
363
+ " in_channels=config[\"in_channels\"],\n",
364
+ " out_channels=config[\"out_channels\"],\n",
365
+ " latent_channels=config[\"latent_channels\"],\n",
366
+ " layers_per_block=config[\"layers_per_block\"],\n",
367
+ " norm_num_groups=config[\"norm_num_groups\"],\n",
368
+ " sample_size=config[\"sample_size\"],\n",
369
+ " scaling_factor=config[\"scaling_factor\"],\n",
370
+ " force_upcast=config[\"force_upcast\"],\n",
371
+ " mid_block_add_attention=config[\"mid_block_add_attention\"],\n",
372
+ " use_quant_conv=config[\"use_quant_conv\"],\n",
373
+ " use_post_quant_conv=config[\"use_post_quant_conv\"],\n",
374
+ " latents_mean=(config[\"latents_mean\"]),\n",
375
+ " latents_std=(config[\"latents_std\"]),\n",
376
+ ")\n",
377
+ "\n",
378
+ "vae.save_pretrained(\"vae16x32ch_empty\")\n",
379
+ "\n",
380
+ "print(vae)"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": 6,
386
+ "id": "a2950158-5203-42b9-8791-e231ddbf1063",
387
+ "metadata": {},
388
+ "outputs": [
389
+ {
390
+ "name": "stderr",
391
+ "output_type": "stream",
392
+ "text": [
393
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False, 'latents_mean': [-0.03542253375053406, 0.20086465775966644, -0.016413161531090736, -0.0956302210688591, -0.2672063112258911, 0.2609933018684387, -0.07806991040706635, -0.48407721519470215, 0.21844269335269928, -0.1122383326292038, 0.27197545766830444, -0.18958772718906403, 0.18776826560497284, 0.0987580344080925, 0.2837068736553192, -0.4486690163612366, 0.4816776514053345, 0.02947971224784851, -0.1337375044822693, -0.39750921726226807, -0.08513020724058151, -0.054023586213588715, -0.3943594992160797, 0.23918119072914124, -0.12466679513454437, 0.09935147315263748, 0.31858691573143005, 0.48585832118988037, -0.6416525840759277, -0.15164820849895477, -0.4693508744239807, -0.13071806728839874], 'latents_std': [1.5792087316513062, 1.5769503116607666, 1.5864241123199463, 1.6454921960830688, 1.5336694717407227, 1.5587652921676636, 1.5838669538497925, 1.5659377574920654, 1.6860467195510864, 1.5192310810089111, 1.573639988899231, 1.5953549146652222, 1.5271092653274536, 1.6246271133422852, 1.7054023742675781, 1.607722282409668, 1.558642864227295, 1.5824549198150635, 1.6202995777130127, 1.6206320524215698, 1.6379750967025757, 1.6527063846588135, 1.498811960220337, 1.5706247091293335, 1.5854856967926025, 1.4828169345855713, 1.5693111419677734, 1.692481517791748, 1.6409776210784912, 1.6216280460357666, 1.6087706089019775, 1.5776633024215698]} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
394
+ "Перенос весов: 100%|██████████| 284/284 [00:00<00:00, 38362.12it/s]\n"
395
+ ]
396
+ },
397
+ {
398
+ "name": "stdout",
399
+ "output_type": "stream",
400
+ "text": [
401
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.conv1.weight (torch.Size([256, 128, 3, 3])) -> encoder.down_blocks.1.resnets.0.conv1.weight (torch.Size([128, 128, 3, 3]))\n",
402
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.conv1.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.0.conv1.bias (torch.Size([128]))\n",
403
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.norm2.weight (torch.Size([256])) -> encoder.down_blocks.1.resnets.0.norm2.weight (torch.Size([128]))\n",
404
+ "✗ Нес��впадение размеров: encoder.down_blocks.1.resnets.0.norm2.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.0.norm2.bias (torch.Size([128]))\n",
405
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.conv2.weight (torch.Size([256, 256, 3, 3])) -> encoder.down_blocks.1.resnets.0.conv2.weight (torch.Size([128, 128, 3, 3]))\n",
406
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.conv2.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.0.conv2.bias (torch.Size([128]))\n",
407
+ "? Ключ не найден в новой модели: encoder.down_blocks.1.resnets.0.conv_shortcut.weight -> torch.Size([256, 128, 1, 1])\n",
408
+ "? Ключ не найден в новой модели: encoder.down_blocks.1.resnets.0.conv_shortcut.bias -> torch.Size([256])\n",
409
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.norm1.weight (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.norm1.weight (torch.Size([128]))\n",
410
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.norm1.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.norm1.bias (torch.Size([128]))\n",
411
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.conv1.weight (torch.Size([256, 256, 3, 3])) -> encoder.down_blocks.1.resnets.1.conv1.weight (torch.Size([128, 128, 3, 3]))\n",
412
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.conv1.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.conv1.bias (torch.Size([128]))\n",
413
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.norm2.weight (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.norm2.weight (torch.Size([128]))\n",
414
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.norm2.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.norm2.bias (torch.Size([128]))\n",
415
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.conv2.weight (torch.Size([256, 256, 3, 3])) -> encoder.down_blocks.1.resnets.1.conv2.weight (torch.Size([128, 128, 3, 3]))\n",
416
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.conv2.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.conv2.bias (torch.Size([128]))\n",
417
+ "✗ Несовпадение размеров: encoder.down_blocks.1.downsamplers.0.conv.weight (torch.Size([256, 256, 3, 3])) -> encoder.down_blocks.1.downsamplers.0.conv.weight (torch.Size([128, 128, 3, 3]))\n",
418
+ "✗ Несовпадение размеров: encoder.down_blocks.1.downsamplers.0.conv.bias (torch.Size([256])) -> encoder.down_blocks.1.downsamplers.0.conv.bias (torch.Size([128]))\n",
419
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.norm1.weight (torch.Size([256])) -> encoder.down_blocks.2.resnets.0.norm1.weight (torch.Size([128]))\n",
420
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.norm1.bias (torch.Size([256])) -> encoder.down_blocks.2.resnets.0.norm1.bias (torch.Size([128]))\n",
421
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv1.weight (torch.Size([512, 256, 3, 3])) -> encoder.down_blocks.2.resnets.0.conv1.weight (torch.Size([256, 128, 3, 3]))\n",
422
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv1.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.conv1.bias (torch.Size([256]))\n",
423
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.norm2.weight (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.norm2.weight (torch.Size([256]))\n",
424
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.norm2.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.norm2.bias (torch.Size([256]))\n",
425
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.2.resnets.0.conv2.weight (torch.Size([256, 256, 3, 3]))\n",
426
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv2.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.conv2.bias (torch.Size([256]))\n",
427
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv_shortcut.weight (torch.Size([512, 256, 1, 1])) -> encoder.down_blocks.2.resnets.0.conv_shortcut.weight (torch.Size([256, 128, 1, 1]))\n",
428
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv_shortcut.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.conv_shortcut.bias (torch.Size([256]))\n",
429
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.norm1.weight (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.norm1.weight (torch.Size([256]))\n",
430
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.norm1.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.norm1.bias (torch.Size([256]))\n",
431
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.2.resnets.1.conv1.weight (torch.Size([256, 256, 3, 3]))\n",
432
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.conv1.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.conv1.bias (torch.Size([256]))\n",
433
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.norm2.weight (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.norm2.weight (torch.Size([256]))\n",
434
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.norm2.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.norm2.bias (torch.Size([256]))\n",
435
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.2.resnets.1.conv2.weight (torch.Size([256, 256, 3, 3]))\n",
436
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.conv2.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.conv2.bias (torch.Size([256]))\n",
437
+ "✗ Несовпадение размеров: encoder.down_blocks.2.downsamplers.0.conv.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.2.downsamplers.0.conv.weight (torch.Size([256, 256, 3, 3]))\n",
438
+ "✗ Несовпадение размеров: encoder.down_blocks.2.downsamplers.0.conv.bias (torch.Size([512])) -> encoder.down_blocks.2.downsamplers.0.conv.bias (torch.Size([256]))\n",
439
+ "✗ Несовпадение размеров: encoder.down_blocks.3.resnets.0.norm1.weight (torch.Size([512])) -> encoder.down_blocks.3.resnets.0.norm1.weight (torch.Size([256]))\n",
440
+ "✗ Несовпадение размеров: encoder.down_blocks.3.resnets.0.norm1.bias (torch.Size([512])) -> encoder.down_blocks.3.resnets.0.norm1.bias (torch.Size([256]))\n",
441
+ "✗ Несовпадение размеров: encoder.down_blocks.3.resnets.0.conv1.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.3.resnets.0.conv1.weight (torch.Size([512, 256, 3, 3]))\n",
442
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.norm1.weight -> torch.Size([128])\n",
443
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.norm1.bias -> torch.Size([128])\n",
444
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.conv1.weight -> torch.Size([128, 128, 3, 3])\n",
445
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.conv1.bias -> torch.Size([128])\n",
446
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.norm2.weight -> torch.Size([128])\n",
447
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.norm2.bias -> torch.Size([128])\n",
448
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.conv2.weight -> torch.Size([128, 128, 3, 3])\n",
449
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.conv2.bias -> torch.Size([128])\n",
450
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.norm1.weight -> torch.Size([128])\n",
451
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.norm1.bias -> torch.Size([128])\n",
452
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.conv1.weight -> torch.Size([128, 128, 3, 3])\n",
453
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.conv1.bias -> torch.Size([128])\n",
454
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.norm2.weight -> torch.Size([128])\n",
455
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.norm2.bias -> torch.Size([128])\n",
456
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.conv2.weight -> torch.Size([128, 128, 3, 3])\n",
457
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.conv2.bias -> torch.Size([128])\n",
458
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.norm1.weight -> torch.Size([128])\n",
459
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.norm1.bias -> torch.Size([128])\n",
460
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.conv1.weight -> torch.Size([128, 128, 3, 3])\n",
461
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.conv1.bias -> torch.Size([128])\n",
462
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.norm2.weight -> torch.Size([128])\n",
463
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.norm2.bias -> torch.Size([128])\n",
464
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.conv2.weight -> torch.Size([128, 128, 3, 3])\n",
465
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.conv2.bias -> torch.Size([128])\n",
466
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.0.weight -> torch.Size([128, 3, 3, 3])\n",
467
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.0.bias -> torch.Size([128])\n",
468
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.1.weight -> torch.Size([256, 128, 3, 3])\n",
469
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.1.bias -> torch.Size([256])\n",
470
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.2.weight -> torch.Size([512, 256, 4, 4])\n",
471
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.2.bias -> torch.Size([512])\n",
472
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.3.weight -> torch.Size([512, 512, 4, 4])\n",
473
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.3.bias -> torch.Size([512])\n",
474
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.4.weight -> torch.Size([512, 512, 4, 4])\n",
475
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.4.bias -> torch.Size([512])\n",
476
+ "Статистика переноса: {'перенесено': 209, 'несовпадение_размеров': 39, 'пропущено': 36}\n",
477
+ "Неперенесенные ключи в новой модели:\n",
478
+ "encoder.down_blocks.1.downsamplers.0.conv.bias\n",
479
+ "encoder.down_blocks.1.downsamplers.0.conv.weight\n",
480
+ "encoder.down_blocks.1.resnets.0.conv1.bias\n",
481
+ "encoder.down_blocks.1.resnets.0.conv1.weight\n",
482
+ "encoder.down_blocks.1.resnets.0.conv2.bias\n",
483
+ "encoder.down_blocks.1.resnets.0.conv2.weight\n",
484
+ "encoder.down_blocks.1.resnets.0.norm2.bias\n",
485
+ "encoder.down_blocks.1.resnets.0.norm2.weight\n",
486
+ "encoder.down_blocks.1.resnets.1.conv1.bias\n",
487
+ "encoder.down_blocks.1.resnets.1.conv1.weight\n",
488
+ "encoder.down_blocks.1.resnets.1.conv2.bias\n",
489
+ "encoder.down_blocks.1.resnets.1.conv2.weight\n",
490
+ "encoder.down_blocks.1.resnets.1.norm1.bias\n",
491
+ "encoder.down_blocks.1.resnets.1.norm1.weight\n",
492
+ "encoder.down_blocks.1.resnets.1.norm2.bias\n",
493
+ "encoder.down_blocks.1.resnets.1.norm2.weight\n",
494
+ "encoder.down_blocks.2.downsamplers.0.conv.bias\n",
495
+ "encoder.down_blocks.2.downsamplers.0.conv.weight\n",
496
+ "encoder.down_blocks.2.resnets.0.conv1.bias\n",
497
+ "encoder.down_blocks.2.resnets.0.conv1.weight\n",
498
+ "encoder.down_blocks.2.resnets.0.conv2.bias\n",
499
+ "encoder.down_blocks.2.resnets.0.conv2.weight\n",
500
+ "encoder.down_blocks.2.resnets.0.conv_shortcut.bias\n",
501
+ "encoder.down_blocks.2.resnets.0.conv_shortcut.weight\n",
502
+ "encoder.down_blocks.2.resnets.0.norm1.bias\n",
503
+ "encoder.down_blocks.2.resnets.0.norm1.weight\n",
504
+ "encoder.down_blocks.2.resnets.0.norm2.bias\n",
505
+ "encoder.down_blocks.2.resnets.0.norm2.weight\n",
506
+ "encoder.down_blocks.2.resnets.1.conv1.bias\n",
507
+ "encoder.down_blocks.2.resnets.1.conv1.weight\n",
508
+ "encoder.down_blocks.2.resnets.1.conv2.bias\n",
509
+ "encoder.down_blocks.2.resnets.1.conv2.weight\n",
510
+ "encoder.down_blocks.2.resnets.1.norm1.bias\n",
511
+ "encoder.down_blocks.2.resnets.1.norm1.weight\n",
512
+ "encoder.down_blocks.2.resnets.1.norm2.bias\n",
513
+ "encoder.down_blocks.2.resnets.1.norm2.weight\n",
514
+ "encoder.down_blocks.3.downsamplers.0.conv.bias\n",
515
+ "encoder.down_blocks.3.downsamplers.0.conv.weight\n",
516
+ "encoder.down_blocks.3.resnets.0.conv1.weight\n",
517
+ "encoder.down_blocks.3.resnets.0.conv_shortcut.bias\n",
518
+ "encoder.down_blocks.3.resnets.0.conv_shortcut.weight\n",
519
+ "encoder.down_blocks.3.resnets.0.norm1.bias\n",
520
+ "encoder.down_blocks.3.resnets.0.norm1.weight\n"
521
+ ]
522
+ }
523
+ ],
524
+ "source": [
525
+ "import torch\n",
526
+ "from diffusers import AutoencoderKL,AsymmetricAutoencoderKL\n",
527
+ "from tqdm import tqdm\n",
528
+ "import torch.nn.init as init\n",
529
+ "\n",
530
+ "def log(message):\n",
531
+ " print(message)\n",
532
+ "\n",
533
+ "def main():\n",
534
+ " checkpoint_path_old = \"asymmetric_vae_new\"\n",
535
+ " checkpoint_path_new = \"vae16x32ch_empty\"\n",
536
+ " device = \"cuda\"\n",
537
+ " dtype = torch.float32\n",
538
+ "\n",
539
+ " # Загрузка моделей\n",
540
+ " old_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
541
+ " new_unet = AutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
542
+ "\n",
543
+ " old_state_dict = old_unet.state_dict()\n",
544
+ " new_state_dict = new_unet.state_dict()\n",
545
+ "\n",
546
+ " transferred_state_dict = {}\n",
547
+ " transfer_stats = {\n",
548
+ " \"перенесено\": 0,\n",
549
+ " \"несовпадение_размеров\": 0,\n",
550
+ " \"пропущено\": 0\n",
551
+ " }\n",
552
+ "\n",
553
+ " transferred_keys = set()\n",
554
+ "\n",
555
+ " # Обрабатываем каждый ключ старой модели\n",
556
+ " for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n",
557
+ " new_key = old_key\n",
558
+ "\n",
559
+ " if new_key in new_state_dict:\n",
560
+ " if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
561
+ " transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
562
+ " transferred_keys.add(new_key)\n",
563
+ " transfer_stats[\"перенесено\"] += 1\n",
564
+ " else:\n",
565
+ " log(f\"✗ Несовпадение размеров: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})\")\n",
566
+ " transfer_stats[\"несовпадение_размеров\"] += 1\n",
567
+ " else:\n",
568
+ " log(f\"? Ключ не найден в новой модели: {old_key} -> {old_state_dict[old_key].shape}\")\n",
569
+ " transfer_stats[\"пропущено\"] += 1\n",
570
+ "\n",
571
+ " # Обновляем состояние новой модели перенесенными весами\n",
572
+ " new_state_dict.update(transferred_state_dict)\n",
573
+ " \n",
574
+ " # Инициализируем веса для нового mid блока\n",
575
+ " #new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)\n",
576
+ " \n",
577
+ " new_unet.load_state_dict(new_state_dict)\n",
578
+ " new_unet.save_pretrained(\"vae16x32ch\")\n",
579
+ "\n",
580
+ " # Получаем список неперенесенных ключей\n",
581
+ " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
582
+ "\n",
583
+ " print(\"Статистика переноса:\", transfer_stats)\n",
584
+ " print(\"Неперенесенные ключи в новой модели:\")\n",
585
+ " for key in non_transferred_keys:\n",
586
+ " print(key)\n",
587
+ "\n",
588
+ "if __name__ == \"__main__\":\n",
589
+ " main()"
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": 1,
595
+ "id": "b316ee6c-d295-4396-9177-78e39a53055b",
596
+ "metadata": {},
597
+ "outputs": [
598
+ {
599
+ "name": "stderr",
600
+ "output_type": "stream",
601
+ "text": [
602
+ "The config attributes {'block_out_channels': [128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
603
+ ]
604
+ },
605
+ {
606
+ "name": "stdout",
607
+ "output_type": "stream",
608
+ "text": [
609
+ "ok\n"
610
+ ]
611
+ }
612
+ ],
613
+ "source": [
614
+ "import torch\n",
615
+ "\n",
616
+ "from torchvision import transforms, utils\n",
617
+ "\n",
618
+ "import diffusers\n",
619
+ "from diffusers import AsymmetricAutoencoderKL\n",
620
+ "\n",
621
+ "from diffusers.utils import load_image\n",
622
+ "\n",
623
+ "def crop_image_to_nearest_divisible_by_8(img):\n",
624
+ " # Check if the image height and width are divisible by 8\n",
625
+ " if img.shape[1] % 8 == 0 and img.shape[2] % 8 == 0:\n",
626
+ " return img\n",
627
+ " else:\n",
628
+ " # Calculate the closest lower resolution divisible by 8\n",
629
+ " new_height = img.shape[1] - (img.shape[1] % 8)\n",
630
+ " new_width = img.shape[2] - (img.shape[2] % 8)\n",
631
+ " \n",
632
+ " # Use CenterCrop to crop the image\n",
633
+ " transform = transforms.CenterCrop((new_height, new_width), interpolation=transforms.InterpolationMode.BILINEAR)\n",
634
+ " img = transform(img).to(torch.float32).clamp(-1, 1)\n",
635
+ " \n",
636
+ " return img\n",
637
+ " \n",
638
+ "to_tensor = transforms.ToTensor()\n",
639
+ "\n",
640
+ "device = \"cuda\"\n",
641
+ "dtype=torch.float16\n",
642
+ "vae = AsymmetricAutoencoderKL.from_pretrained(\"asymmetric_vae\",torch_dtype=dtype).to(device).eval()\n",
643
+ "\n",
644
+ "image = load_image(\"123456789.jpg\")\n",
645
+ "\n",
646
+ "image = crop_image_to_nearest_divisible_by_8(to_tensor(image)).unsqueeze(0).to(device,dtype=dtype)\n",
647
+ "\n",
648
+ "upscaled_image = vae(image).sample\n",
649
+ "#vae.config.scaled_factor\n",
650
+ "# Save the reconstructed image\n",
651
+ "utils.save_image(upscaled_image, \"test.png\")\n",
652
+ "print('ok')"
653
+ ]
654
+ },
655
+ {
656
+ "cell_type": "code",
657
+ "execution_count": 11,
658
+ "id": "5a01b8e9-73c9-4da7-a097-e334019bd8e9",
659
+ "metadata": {},
660
+ "outputs": [
661
+ {
662
+ "name": "stderr",
663
+ "output_type": "stream",
664
+ "text": [
665
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False, 'latents_mean': [-0.03542253375053406, 0.20086465775966644, -0.016413161531090736, -0.0956302210688591, -0.2672063112258911, 0.2609933018684387, -0.07806991040706635, -0.48407721519470215, 0.21844269335269928, -0.1122383326292038, 0.27197545766830444, -0.18958772718906403, 0.18776826560497284, 0.0987580344080925, 0.2837068736553192, -0.4486690163612366, 0.4816776514053345, 0.02947971224784851, -0.1337375044822693, -0.39750921726226807, -0.08513020724058151, -0.054023586213588715, -0.3943594992160797, 0.23918119072914124, -0.12466679513454437, 0.09935147315263748, 0.31858691573143005, 0.48585832118988037, -0.6416525840759277, -0.15164820849895477, -0.4693508744239807, -0.13071806728839874], 'latents_std': [1.5792087316513062, 1.5769503116607666, 1.5864241123199463, 1.6454921960830688, 1.5336694717407227, 1.5587652921676636, 1.5838669538497925, 1.5659377574920654, 1.6860467195510864, 1.5192310810089111, 1.573639988899231, 1.5953549146652222, 1.5271092653274536, 1.6246271133422852, 1.7054023742675781, 1.607722282409668, 1.558642864227295, 1.5824549198150635, 1.6202995777130127, 1.6206320524215698, 1.6379750967025757, 1.6527063846588135, 1.498811960220337, 1.5706247091293335, 1.5854856967926025, 1.4828169345855713, 1.5693111419677734, 1.692481517791748, 1.6409776210784912, 1.6216280460357666, 1.6087706089019775, 1.5776633024215698]} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
666
+ "Перенос весов: 100%|██████████| 284/284 [00:00<00:00, 30094.80it/s]\n"
667
+ ]
668
+ },
669
+ {
670
+ "name": "stdout",
671
+ "output_type": "stream",
672
+ "text": [
673
+ "Статистика: {'перенесено': 292, 'несовпадение_размеров': 0, 'пропущено': 10}\n",
674
+ "\n",
675
+ "Неперенесенные ключи:\n"
676
+ ]
677
+ }
678
+ ],
679
+ "source": [
680
+ "import torch\n",
681
+ "from diffusers import AutoencoderKL, AsymmetricAutoencoderKL\n",
682
+ "from tqdm import tqdm\n",
683
+ "\n",
684
+ "\n",
685
+ "def log(message):\n",
686
+ " print(message)\n",
687
+ "\n",
688
+ "\n",
689
+ "def remap_key(old_key: str):\n",
690
+ " \"\"\"\n",
691
+ " Смещение только encoder.down_blocks\n",
692
+ " \"\"\"\n",
693
+ "\n",
694
+ " if \"encoder.down_blocks\" not in old_key:\n",
695
+ " return [old_key]\n",
696
+ "\n",
697
+ " parts = old_key.split(\".\")\n",
698
+ " block_id = int(parts[2])\n",
699
+ "\n",
700
+ " if block_id == 0:\n",
701
+ " # первый блок копируем дважды\n",
702
+ " return [\n",
703
+ " old_key.replace(\"down_blocks.0\", \"down_blocks.0\"),\n",
704
+ " old_key.replace(\"down_blocks.0\", \"down_blocks.1\"),\n",
705
+ " ]\n",
706
+ "\n",
707
+ " # остальные блоки сдвигаем\n",
708
+ " new_block = block_id + 1\n",
709
+ " return [old_key.replace(f\"down_blocks.{block_id}\", f\"down_blocks.{new_block}\")]\n",
710
+ "\n",
711
+ "\n",
712
+ "def main():\n",
713
+ " checkpoint_path_old = \"asymmetric_vae_new\"\n",
714
+ " checkpoint_path_new = \"vae16x32ch_empty\"\n",
715
+ "\n",
716
+ " device = \"cuda\"\n",
717
+ " dtype = torch.float32\n",
718
+ "\n",
719
+ " old_vae = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
720
+ " new_vae = AutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
721
+ "\n",
722
+ " old_state_dict = old_vae.state_dict()\n",
723
+ " new_state_dict = new_vae.state_dict()\n",
724
+ "\n",
725
+ " transferred_state_dict = {}\n",
726
+ " transferred_keys = set()\n",
727
+ "\n",
728
+ " transfer_stats = {\n",
729
+ " \"перенесено\": 0,\n",
730
+ " \"несовпадение_размеров\": 0,\n",
731
+ " \"пропущено\": 0\n",
732
+ " }\n",
733
+ "\n",
734
+ " for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n",
735
+ "\n",
736
+ " new_keys = remap_key(old_key)\n",
737
+ "\n",
738
+ " for new_key in new_keys:\n",
739
+ "\n",
740
+ " if new_key in new_state_dict:\n",
741
+ "\n",
742
+ " if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
743
+ " transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
744
+ " transferred_keys.add(new_key)\n",
745
+ " transfer_stats[\"перенесено\"] += 1\n",
746
+ " else:\n",
747
+ " log(\n",
748
+ " f\"✗ Несовпадение размеров: \"\n",
749
+ " f\"{old_key} {old_state_dict[old_key].shape} \"\n",
750
+ " f\"-> {new_key} {new_state_dict[new_key].shape}\"\n",
751
+ " )\n",
752
+ " transfer_stats[\"несовпадение_р��змеров\"] += 1\n",
753
+ " else:\n",
754
+ " transfer_stats[\"пропущено\"] += 1\n",
755
+ "\n",
756
+ " new_state_dict.update(transferred_state_dict)\n",
757
+ "\n",
758
+ " new_vae.load_state_dict(new_state_dict)\n",
759
+ " new_vae.save_pretrained(\"vae16x32ch\")\n",
760
+ "\n",
761
+ " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
762
+ "\n",
763
+ " print(\"Статистика:\", transfer_stats)\n",
764
+ "\n",
765
+ " print(\"\\nНеперенесенные ключи:\")\n",
766
+ " for key in non_transferred_keys:\n",
767
+ " print(key)\n",
768
+ "\n",
769
+ "\n",
770
+ "if __name__ == \"__main__\":\n",
771
+ " main()"
772
+ ]
773
+ },
774
+ {
775
+ "cell_type": "code",
776
+ "execution_count": null,
777
+ "id": "fe8f1ceb-8d3e-4df5-a1dc-1b56a0d398a2",
778
+ "metadata": {},
779
+ "outputs": [],
780
+ "source": []
781
+ }
782
+ ],
783
+ "metadata": {
784
+ "kernelspec": {
785
+ "display_name": "Python3 (ipykernel)",
786
+ "language": "python",
787
+ "name": "python3"
788
+ },
789
+ "language_info": {
790
+ "codemirror_mode": {
791
+ "name": "ipython",
792
+ "version": 3
793
+ },
794
+ "file_extension": ".py",
795
+ "mimetype": "text/x-python",
796
+ "name": "python",
797
+ "nbconvert_exporter": "python",
798
+ "pygments_lexer": "ipython3",
799
+ "version": "3.12.12"
800
+ }
801
+ },
802
+ "nbformat": 4,
803
+ "nbformat_minor": 5
804
+ }
config.json ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.37.0",
4
+ "_name_or_path": "vae16x32ch",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 128,
9
+ 256,
10
+ 512,
11
+ 512
12
+ ],
13
+ "down_block_types": [
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D",
17
+ "DownEncoderBlock2D",
18
+ "DownEncoderBlock2D"
19
+ ],
20
+ "force_upcast": false,
21
+ "in_channels": 3,
22
+ "latent_channels": 32,
23
+ "latents_mean": [
24
+ -0.03542253375053406,
25
+ 0.20086465775966644,
26
+ -0.016413161531090736,
27
+ -0.0956302210688591,
28
+ -0.2672063112258911,
29
+ 0.2609933018684387,
30
+ -0.07806991040706635,
31
+ -0.48407721519470215,
32
+ 0.21844269335269928,
33
+ -0.1122383326292038,
34
+ 0.27197545766830444,
35
+ -0.18958772718906403,
36
+ 0.18776826560497284,
37
+ 0.0987580344080925,
38
+ 0.2837068736553192,
39
+ -0.4486690163612366,
40
+ 0.4816776514053345,
41
+ 0.02947971224784851,
42
+ -0.1337375044822693,
43
+ -0.39750921726226807,
44
+ -0.08513020724058151,
45
+ -0.054023586213588715,
46
+ -0.3943594992160797,
47
+ 0.23918119072914124,
48
+ -0.12466679513454437,
49
+ 0.09935147315263748,
50
+ 0.31858691573143005,
51
+ 0.48585832118988037,
52
+ -0.6416525840759277,
53
+ -0.15164820849895477,
54
+ -0.4693508744239807,
55
+ -0.13071806728839874
56
+ ],
57
+ "latents_std": [
58
+ 1.5792087316513062,
59
+ 1.5769503116607666,
60
+ 1.5864241123199463,
61
+ 1.6454921960830688,
62
+ 1.5336694717407227,
63
+ 1.5587652921676636,
64
+ 1.5838669538497925,
65
+ 1.5659377574920654,
66
+ 1.6860467195510864,
67
+ 1.5192310810089111,
68
+ 1.573639988899231,
69
+ 1.5953549146652222,
70
+ 1.5271092653274536,
71
+ 1.6246271133422852,
72
+ 1.7054023742675781,
73
+ 1.607722282409668,
74
+ 1.558642864227295,
75
+ 1.5824549198150635,
76
+ 1.6202995777130127,
77
+ 1.6206320524215698,
78
+ 1.6379750967025757,
79
+ 1.6527063846588135,
80
+ 1.498811960220337,
81
+ 1.5706247091293335,
82
+ 1.5854856967926025,
83
+ 1.4828169345855713,
84
+ 1.5693111419677734,
85
+ 1.692481517791748,
86
+ 1.6409776210784912,
87
+ 1.6216280460357666,
88
+ 1.6087706089019775,
89
+ 1.5776633024215698
90
+ ],
91
+ "layers_per_block": 2,
92
+ "mid_block_add_attention": true,
93
+ "norm_num_groups": 32,
94
+ "out_channels": 3,
95
+ "sample_size": 32,
96
+ "scaling_factor": 1.0,
97
+ "shift_factor": null,
98
+ "up_block_types": [
99
+ "UpDecoderBlock2D",
100
+ "UpDecoderBlock2D",
101
+ "UpDecoderBlock2D",
102
+ "UpDecoderBlock2D",
103
+ "UpDecoderBlock2D"
104
+ ],
105
+ "use_post_quant_conv": true,
106
+ "use_quant_conv": true
107
+ }
create_symmetric.ipynb ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 10,
6
+ "id": "407171be-ab46-442b-a0bd-83ca75173eba",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "AutoencoderKL(\n",
14
+ " (encoder): Encoder(\n",
15
+ " (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
16
+ " (down_blocks): ModuleList(\n",
17
+ " (0-1): 2 x DownEncoderBlock2D(\n",
18
+ " (resnets): ModuleList(\n",
19
+ " (0-1): 2 x ResnetBlock2D(\n",
20
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
21
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
22
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
23
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
24
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
25
+ " (nonlinearity): SiLU()\n",
26
+ " )\n",
27
+ " )\n",
28
+ " (downsamplers): ModuleList(\n",
29
+ " (0): Downsample2D(\n",
30
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))\n",
31
+ " )\n",
32
+ " )\n",
33
+ " )\n",
34
+ " (2): DownEncoderBlock2D(\n",
35
+ " (resnets): ModuleList(\n",
36
+ " (0): ResnetBlock2D(\n",
37
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
38
+ " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
39
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
40
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
41
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
42
+ " (nonlinearity): SiLU()\n",
43
+ " (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
44
+ " )\n",
45
+ " (1): ResnetBlock2D(\n",
46
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
47
+ " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
48
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
49
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
50
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
51
+ " (nonlinearity): SiLU()\n",
52
+ " )\n",
53
+ " )\n",
54
+ " (downsamplers): ModuleList(\n",
55
+ " (0): Downsample2D(\n",
56
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n",
57
+ " )\n",
58
+ " )\n",
59
+ " )\n",
60
+ " (3): DownEncoderBlock2D(\n",
61
+ " (resnets): ModuleList(\n",
62
+ " (0): ResnetBlock2D(\n",
63
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
64
+ " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
65
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
66
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
67
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
68
+ " (nonlinearity): SiLU()\n",
69
+ " (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
70
+ " )\n",
71
+ " (1): ResnetBlock2D(\n",
72
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
73
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
74
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
75
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
76
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
77
+ " (nonlinearity): SiLU()\n",
78
+ " )\n",
79
+ " )\n",
80
+ " (downsamplers): ModuleList(\n",
81
+ " (0): Downsample2D(\n",
82
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))\n",
83
+ " )\n",
84
+ " )\n",
85
+ " )\n",
86
+ " (4): DownEncoderBlock2D(\n",
87
+ " (resnets): ModuleList(\n",
88
+ " (0-1): 2 x ResnetBlock2D(\n",
89
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
90
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
91
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
92
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
93
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
94
+ " (nonlinearity): SiLU()\n",
95
+ " )\n",
96
+ " )\n",
97
+ " )\n",
98
+ " )\n",
99
+ " (mid_block): UNetMidBlock2D(\n",
100
+ " (attentions): ModuleList(\n",
101
+ " (0): Attention(\n",
102
+ " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
103
+ " (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
104
+ " (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
105
+ " (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
106
+ " (to_out): ModuleList(\n",
107
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
108
+ " (1): Dropout(p=0.0, inplace=False)\n",
109
+ " )\n",
110
+ " )\n",
111
+ " )\n",
112
+ " (resnets): ModuleList(\n",
113
+ " (0-1): 2 x ResnetBlock2D(\n",
114
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
115
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
116
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
117
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
118
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
119
+ " (nonlinearity): SiLU()\n",
120
+ " )\n",
121
+ " )\n",
122
+ " )\n",
123
+ " (conv_norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
124
+ " (conv_act): SiLU()\n",
125
+ " (conv_out): Conv2d(512, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
126
+ " )\n",
127
+ " (decoder): Decoder(\n",
128
+ " (conv_in): Conv2d(32, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
129
+ " (up_blocks): ModuleList(\n",
130
+ " (0-1): 2 x UpDecoderBlock2D(\n",
131
+ " (resnets): ModuleList(\n",
132
+ " (0-2): 3 x ResnetBlock2D(\n",
133
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
134
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
135
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
136
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
137
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
138
+ " (nonlinearity): SiLU()\n",
139
+ " )\n",
140
+ " )\n",
141
+ " (upsamplers): ModuleList(\n",
142
+ " (0): Upsample2D(\n",
143
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
144
+ " )\n",
145
+ " )\n",
146
+ " )\n",
147
+ " (2): UpDecoderBlock2D(\n",
148
+ " (resnets): ModuleList(\n",
149
+ " (0): ResnetBlock2D(\n",
150
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
151
+ " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
152
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
153
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
154
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
155
+ " (nonlinearity): SiLU()\n",
156
+ " (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
157
+ " )\n",
158
+ " (1-2): 2 x ResnetBlock2D(\n",
159
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
160
+ " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
161
+ " (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
162
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
163
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
164
+ " (nonlinearity): SiLU()\n",
165
+ " )\n",
166
+ " )\n",
167
+ " (upsamplers): ModuleList(\n",
168
+ " (0): Upsample2D(\n",
169
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
170
+ " )\n",
171
+ " )\n",
172
+ " )\n",
173
+ " (3): UpDecoderBlock2D(\n",
174
+ " (resnets): ModuleList(\n",
175
+ " (0): ResnetBlock2D(\n",
176
+ " (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)\n",
177
+ " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
178
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
179
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
180
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
181
+ " (nonlinearity): SiLU()\n",
182
+ " (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
183
+ " )\n",
184
+ " (1-2): 2 x ResnetBlock2D(\n",
185
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
186
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
187
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
188
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
189
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
190
+ " (nonlinearity): SiLU()\n",
191
+ " )\n",
192
+ " )\n",
193
+ " (upsamplers): ModuleList(\n",
194
+ " (0): Upsample2D(\n",
195
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
196
+ " )\n",
197
+ " )\n",
198
+ " )\n",
199
+ " (4): UpDecoderBlock2D(\n",
200
+ " (resnets): ModuleList(\n",
201
+ " (0-2): 3 x ResnetBlock2D(\n",
202
+ " (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
203
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
204
+ " (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
205
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
206
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
207
+ " (nonlinearity): SiLU()\n",
208
+ " )\n",
209
+ " )\n",
210
+ " )\n",
211
+ " )\n",
212
+ " (mid_block): UNetMidBlock2D(\n",
213
+ " (attentions): ModuleList(\n",
214
+ " (0): Attention(\n",
215
+ " (group_norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
216
+ " (to_q): Linear(in_features=512, out_features=512, bias=True)\n",
217
+ " (to_k): Linear(in_features=512, out_features=512, bias=True)\n",
218
+ " (to_v): Linear(in_features=512, out_features=512, bias=True)\n",
219
+ " (to_out): ModuleList(\n",
220
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
221
+ " (1): Dropout(p=0.0, inplace=False)\n",
222
+ " )\n",
223
+ " )\n",
224
+ " )\n",
225
+ " (resnets): ModuleList(\n",
226
+ " (0-1): 2 x ResnetBlock2D(\n",
227
+ " (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
228
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
229
+ " (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
230
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
231
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
232
+ " (nonlinearity): SiLU()\n",
233
+ " )\n",
234
+ " )\n",
235
+ " )\n",
236
+ " (conv_norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)\n",
237
+ " (conv_act): SiLU()\n",
238
+ " (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
239
+ " )\n",
240
+ " (quant_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
241
+ " (post_quant_conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
242
+ ")\n"
243
+ ]
244
+ }
245
+ ],
246
+ "source": [
247
+ "from diffusers.models import AutoencoderKL\n",
248
+ "import torch\n",
249
+ "\n",
250
+ "config = {\n",
251
+ " \"_class_name\": \"AutoencoderKL\",\n",
252
+ " \"_diffusers_version\": \"0.36.0\",\n",
253
+ " \"act_fn\": \"silu\",\n",
254
+ " \"block_out_channels\": [\n",
255
+ " 128,\n",
256
+ " 128,\n",
257
+ " 256,\n",
258
+ " 512,\n",
259
+ " 512\n",
260
+ " ],\n",
261
+ " \"down_block_types\": [\n",
262
+ " \"DownEncoderBlock2D\",\n",
263
+ " \"DownEncoderBlock2D\",\n",
264
+ " \"DownEncoderBlock2D\",\n",
265
+ " \"DownEncoderBlock2D\",\n",
266
+ " \"DownEncoderBlock2D\"\n",
267
+ " ],\n",
268
+ " \"force_upcast\": False,\n",
269
+ " \"in_channels\": 3,\n",
270
+ " \"latent_channels\": 32,\n",
271
+ " \"latents_mean\": [\n",
272
+ " -0.03542253375053406,\n",
273
+ " 0.20086465775966644,\n",
274
+ " -0.016413161531090736,\n",
275
+ " -0.0956302210688591,\n",
276
+ " -0.2672063112258911,\n",
277
+ " 0.2609933018684387,\n",
278
+ " -0.07806991040706635,\n",
279
+ " -0.48407721519470215,\n",
280
+ " 0.21844269335269928,\n",
281
+ " -0.1122383326292038,\n",
282
+ " 0.27197545766830444,\n",
283
+ " -0.18958772718906403,\n",
284
+ " 0.18776826560497284,\n",
285
+ " 0.0987580344080925,\n",
286
+ " 0.2837068736553192,\n",
287
+ " -0.4486690163612366,\n",
288
+ " 0.4816776514053345,\n",
289
+ " 0.02947971224784851,\n",
290
+ " -0.1337375044822693,\n",
291
+ " -0.39750921726226807,\n",
292
+ " -0.08513020724058151,\n",
293
+ " -0.054023586213588715,\n",
294
+ " -0.3943594992160797,\n",
295
+ " 0.23918119072914124,\n",
296
+ " -0.12466679513454437,\n",
297
+ " 0.09935147315263748,\n",
298
+ " 0.31858691573143005,\n",
299
+ " 0.48585832118988037,\n",
300
+ " -0.6416525840759277,\n",
301
+ " -0.15164820849895477,\n",
302
+ " -0.4693508744239807,\n",
303
+ " -0.13071806728839874\n",
304
+ " ],\n",
305
+ " \"latents_std\": [\n",
306
+ " 1.5792087316513062,\n",
307
+ " 1.5769503116607666,\n",
308
+ " 1.5864241123199463,\n",
309
+ " 1.6454921960830688,\n",
310
+ " 1.5336694717407227,\n",
311
+ " 1.5587652921676636,\n",
312
+ " 1.5838669538497925,\n",
313
+ " 1.5659377574920654,\n",
314
+ " 1.6860467195510864,\n",
315
+ " 1.5192310810089111,\n",
316
+ " 1.573639988899231,\n",
317
+ " 1.5953549146652222,\n",
318
+ " 1.5271092653274536,\n",
319
+ " 1.6246271133422852,\n",
320
+ " 1.7054023742675781,\n",
321
+ " 1.607722282409668,\n",
322
+ " 1.558642864227295,\n",
323
+ " 1.5824549198150635,\n",
324
+ " 1.6202995777130127,\n",
325
+ " 1.6206320524215698,\n",
326
+ " 1.6379750967025757,\n",
327
+ " 1.6527063846588135,\n",
328
+ " 1.498811960220337,\n",
329
+ " 1.5706247091293335,\n",
330
+ " 1.5854856967926025,\n",
331
+ " 1.4828169345855713,\n",
332
+ " 1.5693111419677734,\n",
333
+ " 1.692481517791748,\n",
334
+ " 1.6409776210784912,\n",
335
+ " 1.6216280460357666,\n",
336
+ " 1.6087706089019775,\n",
337
+ " 1.5776633024215698\n",
338
+ " ],\n",
339
+ " \"layers_per_block\": 2,\n",
340
+ " \"mid_block_add_attention\": True,\n",
341
+ " \"norm_num_groups\": 32,\n",
342
+ " \"out_channels\": 3,\n",
343
+ " \"sample_size\": 32,\n",
344
+ " \"scaling_factor\": 1.0,\n",
345
+ " \"shift_factor\": 0.0,\n",
346
+ " \"up_block_types\": [\n",
347
+ " \"UpDecoderBlock2D\",\n",
348
+ " \"UpDecoderBlock2D\",\n",
349
+ " \"UpDecoderBlock2D\",\n",
350
+ " \"UpDecoderBlock2D\",\n",
351
+ " \"UpDecoderBlock2D\"\n",
352
+ " ],\n",
353
+ " \"use_post_quant_conv\": True,\n",
354
+ " \"use_quant_conv\": True\n",
355
+ "}\n",
356
+ "\n",
357
+ "\n",
358
+ "vae = AutoencoderKL(\n",
359
+ " act_fn=config[\"act_fn\"],\n",
360
+ " block_out_channels=config[\"block_out_channels\"],\n",
361
+ " down_block_types=config[\"down_block_types\"],\n",
362
+ " up_block_types=config[\"up_block_types\"],\n",
363
+ " in_channels=config[\"in_channels\"],\n",
364
+ " out_channels=config[\"out_channels\"],\n",
365
+ " latent_channels=config[\"latent_channels\"],\n",
366
+ " layers_per_block=config[\"layers_per_block\"],\n",
367
+ " norm_num_groups=config[\"norm_num_groups\"],\n",
368
+ " sample_size=config[\"sample_size\"],\n",
369
+ " scaling_factor=config[\"scaling_factor\"],\n",
370
+ " force_upcast=config[\"force_upcast\"],\n",
371
+ " mid_block_add_attention=config[\"mid_block_add_attention\"],\n",
372
+ " use_quant_conv=config[\"use_quant_conv\"],\n",
373
+ " use_post_quant_conv=config[\"use_post_quant_conv\"],\n",
374
+ " latents_mean=(config[\"latents_mean\"]),\n",
375
+ " latents_std=(config[\"latents_std\"]),\n",
376
+ ")\n",
377
+ "\n",
378
+ "vae.save_pretrained(\"vae16x32ch_empty\")\n",
379
+ "\n",
380
+ "print(vae)"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": 6,
386
+ "id": "a2950158-5203-42b9-8791-e231ddbf1063",
387
+ "metadata": {},
388
+ "outputs": [
389
+ {
390
+ "name": "stderr",
391
+ "output_type": "stream",
392
+ "text": [
393
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False, 'latents_mean': [-0.03542253375053406, 0.20086465775966644, -0.016413161531090736, -0.0956302210688591, -0.2672063112258911, 0.2609933018684387, -0.07806991040706635, -0.48407721519470215, 0.21844269335269928, -0.1122383326292038, 0.27197545766830444, -0.18958772718906403, 0.18776826560497284, 0.0987580344080925, 0.2837068736553192, -0.4486690163612366, 0.4816776514053345, 0.02947971224784851, -0.1337375044822693, -0.39750921726226807, -0.08513020724058151, -0.054023586213588715, -0.3943594992160797, 0.23918119072914124, -0.12466679513454437, 0.09935147315263748, 0.31858691573143005, 0.48585832118988037, -0.6416525840759277, -0.15164820849895477, -0.4693508744239807, -0.13071806728839874], 'latents_std': [1.5792087316513062, 1.5769503116607666, 1.5864241123199463, 1.6454921960830688, 1.5336694717407227, 1.5587652921676636, 1.5838669538497925, 1.5659377574920654, 1.6860467195510864, 1.5192310810089111, 1.573639988899231, 1.5953549146652222, 1.5271092653274536, 1.6246271133422852, 1.7054023742675781, 1.607722282409668, 1.558642864227295, 1.5824549198150635, 1.6202995777130127, 1.6206320524215698, 1.6379750967025757, 1.6527063846588135, 1.498811960220337, 1.5706247091293335, 1.5854856967926025, 1.4828169345855713, 1.5693111419677734, 1.692481517791748, 1.6409776210784912, 1.6216280460357666, 1.6087706089019775, 1.5776633024215698]} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
394
+ "Перенос весов: 100%|██████████| 284/284 [00:00<00:00, 38362.12it/s]\n"
395
+ ]
396
+ },
397
+ {
398
+ "name": "stdout",
399
+ "output_type": "stream",
400
+ "text": [
401
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.conv1.weight (torch.Size([256, 128, 3, 3])) -> encoder.down_blocks.1.resnets.0.conv1.weight (torch.Size([128, 128, 3, 3]))\n",
402
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.conv1.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.0.conv1.bias (torch.Size([128]))\n",
403
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.norm2.weight (torch.Size([256])) -> encoder.down_blocks.1.resnets.0.norm2.weight (torch.Size([128]))\n",
404
+ "✗ Нес��впадение размеров: encoder.down_blocks.1.resnets.0.norm2.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.0.norm2.bias (torch.Size([128]))\n",
405
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.conv2.weight (torch.Size([256, 256, 3, 3])) -> encoder.down_blocks.1.resnets.0.conv2.weight (torch.Size([128, 128, 3, 3]))\n",
406
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.0.conv2.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.0.conv2.bias (torch.Size([128]))\n",
407
+ "? Ключ не найден в новой модели: encoder.down_blocks.1.resnets.0.conv_shortcut.weight -> torch.Size([256, 128, 1, 1])\n",
408
+ "? Ключ не найден в новой модели: encoder.down_blocks.1.resnets.0.conv_shortcut.bias -> torch.Size([256])\n",
409
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.norm1.weight (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.norm1.weight (torch.Size([128]))\n",
410
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.norm1.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.norm1.bias (torch.Size([128]))\n",
411
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.conv1.weight (torch.Size([256, 256, 3, 3])) -> encoder.down_blocks.1.resnets.1.conv1.weight (torch.Size([128, 128, 3, 3]))\n",
412
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.conv1.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.conv1.bias (torch.Size([128]))\n",
413
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.norm2.weight (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.norm2.weight (torch.Size([128]))\n",
414
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.norm2.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.norm2.bias (torch.Size([128]))\n",
415
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.conv2.weight (torch.Size([256, 256, 3, 3])) -> encoder.down_blocks.1.resnets.1.conv2.weight (torch.Size([128, 128, 3, 3]))\n",
416
+ "✗ Несовпадение размеров: encoder.down_blocks.1.resnets.1.conv2.bias (torch.Size([256])) -> encoder.down_blocks.1.resnets.1.conv2.bias (torch.Size([128]))\n",
417
+ "✗ Несовпадение размеров: encoder.down_blocks.1.downsamplers.0.conv.weight (torch.Size([256, 256, 3, 3])) -> encoder.down_blocks.1.downsamplers.0.conv.weight (torch.Size([128, 128, 3, 3]))\n",
418
+ "✗ Несовпадение размеров: encoder.down_blocks.1.downsamplers.0.conv.bias (torch.Size([256])) -> encoder.down_blocks.1.downsamplers.0.conv.bias (torch.Size([128]))\n",
419
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.norm1.weight (torch.Size([256])) -> encoder.down_blocks.2.resnets.0.norm1.weight (torch.Size([128]))\n",
420
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.norm1.bias (torch.Size([256])) -> encoder.down_blocks.2.resnets.0.norm1.bias (torch.Size([128]))\n",
421
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv1.weight (torch.Size([512, 256, 3, 3])) -> encoder.down_blocks.2.resnets.0.conv1.weight (torch.Size([256, 128, 3, 3]))\n",
422
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv1.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.conv1.bias (torch.Size([256]))\n",
423
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.norm2.weight (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.norm2.weight (torch.Size([256]))\n",
424
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.norm2.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.norm2.bias (torch.Size([256]))\n",
425
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv2.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.2.resnets.0.conv2.weight (torch.Size([256, 256, 3, 3]))\n",
426
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv2.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.conv2.bias (torch.Size([256]))\n",
427
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv_shortcut.weight (torch.Size([512, 256, 1, 1])) -> encoder.down_blocks.2.resnets.0.conv_shortcut.weight (torch.Size([256, 128, 1, 1]))\n",
428
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.0.conv_shortcut.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.0.conv_shortcut.bias (torch.Size([256]))\n",
429
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.norm1.weight (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.norm1.weight (torch.Size([256]))\n",
430
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.norm1.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.norm1.bias (torch.Size([256]))\n",
431
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.conv1.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.2.resnets.1.conv1.weight (torch.Size([256, 256, 3, 3]))\n",
432
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.conv1.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.conv1.bias (torch.Size([256]))\n",
433
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.norm2.weight (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.norm2.weight (torch.Size([256]))\n",
434
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.norm2.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.norm2.bias (torch.Size([256]))\n",
435
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.conv2.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.2.resnets.1.conv2.weight (torch.Size([256, 256, 3, 3]))\n",
436
+ "✗ Несовпадение размеров: encoder.down_blocks.2.resnets.1.conv2.bias (torch.Size([512])) -> encoder.down_blocks.2.resnets.1.conv2.bias (torch.Size([256]))\n",
437
+ "✗ Несовпадение размеров: encoder.down_blocks.2.downsamplers.0.conv.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.2.downsamplers.0.conv.weight (torch.Size([256, 256, 3, 3]))\n",
438
+ "✗ Несовпадение размеров: encoder.down_blocks.2.downsamplers.0.conv.bias (torch.Size([512])) -> encoder.down_blocks.2.downsamplers.0.conv.bias (torch.Size([256]))\n",
439
+ "✗ Несовпадение размеров: encoder.down_blocks.3.resnets.0.norm1.weight (torch.Size([512])) -> encoder.down_blocks.3.resnets.0.norm1.weight (torch.Size([256]))\n",
440
+ "✗ Несовпадение размеров: encoder.down_blocks.3.resnets.0.norm1.bias (torch.Size([512])) -> encoder.down_blocks.3.resnets.0.norm1.bias (torch.Size([256]))\n",
441
+ "✗ Несовпадение размеров: encoder.down_blocks.3.resnets.0.conv1.weight (torch.Size([512, 512, 3, 3])) -> encoder.down_blocks.3.resnets.0.conv1.weight (torch.Size([512, 256, 3, 3]))\n",
442
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.norm1.weight -> torch.Size([128])\n",
443
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.norm1.bias -> torch.Size([128])\n",
444
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.conv1.weight -> torch.Size([128, 128, 3, 3])\n",
445
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.conv1.bias -> torch.Size([128])\n",
446
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.norm2.weight -> torch.Size([128])\n",
447
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.norm2.bias -> torch.Size([128])\n",
448
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.conv2.weight -> torch.Size([128, 128, 3, 3])\n",
449
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.0.conv2.bias -> torch.Size([128])\n",
450
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.norm1.weight -> torch.Size([128])\n",
451
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.norm1.bias -> torch.Size([128])\n",
452
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.conv1.weight -> torch.Size([128, 128, 3, 3])\n",
453
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.conv1.bias -> torch.Size([128])\n",
454
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.norm2.weight -> torch.Size([128])\n",
455
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.norm2.bias -> torch.Size([128])\n",
456
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.conv2.weight -> torch.Size([128, 128, 3, 3])\n",
457
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.1.conv2.bias -> torch.Size([128])\n",
458
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.norm1.weight -> torch.Size([128])\n",
459
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.norm1.bias -> torch.Size([128])\n",
460
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.conv1.weight -> torch.Size([128, 128, 3, 3])\n",
461
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.conv1.bias -> torch.Size([128])\n",
462
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.norm2.weight -> torch.Size([128])\n",
463
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.norm2.bias -> torch.Size([128])\n",
464
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.conv2.weight -> torch.Size([128, 128, 3, 3])\n",
465
+ "? Ключ не найден в новой модели: decoder.up_blocks.4.resnets.2.conv2.bias -> torch.Size([128])\n",
466
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.0.weight -> torch.Size([128, 3, 3, 3])\n",
467
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.0.bias -> torch.Size([128])\n",
468
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.1.weight -> torch.Size([256, 128, 3, 3])\n",
469
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.1.bias -> torch.Size([256])\n",
470
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.2.weight -> torch.Size([512, 256, 4, 4])\n",
471
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.2.bias -> torch.Size([512])\n",
472
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.3.weight -> torch.Size([512, 512, 4, 4])\n",
473
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.3.bias -> torch.Size([512])\n",
474
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.4.weight -> torch.Size([512, 512, 4, 4])\n",
475
+ "? Ключ не найден в новой модели: decoder.condition_encoder.layers.4.bias -> torch.Size([512])\n",
476
+ "Статистика переноса: {'перенесено': 209, 'несовпадение_размеров': 39, 'пропущено': 36}\n",
477
+ "Неперенесенные ключи в новой модели:\n",
478
+ "encoder.down_blocks.1.downsamplers.0.conv.bias\n",
479
+ "encoder.down_blocks.1.downsamplers.0.conv.weight\n",
480
+ "encoder.down_blocks.1.resnets.0.conv1.bias\n",
481
+ "encoder.down_blocks.1.resnets.0.conv1.weight\n",
482
+ "encoder.down_blocks.1.resnets.0.conv2.bias\n",
483
+ "encoder.down_blocks.1.resnets.0.conv2.weight\n",
484
+ "encoder.down_blocks.1.resnets.0.norm2.bias\n",
485
+ "encoder.down_blocks.1.resnets.0.norm2.weight\n",
486
+ "encoder.down_blocks.1.resnets.1.conv1.bias\n",
487
+ "encoder.down_blocks.1.resnets.1.conv1.weight\n",
488
+ "encoder.down_blocks.1.resnets.1.conv2.bias\n",
489
+ "encoder.down_blocks.1.resnets.1.conv2.weight\n",
490
+ "encoder.down_blocks.1.resnets.1.norm1.bias\n",
491
+ "encoder.down_blocks.1.resnets.1.norm1.weight\n",
492
+ "encoder.down_blocks.1.resnets.1.norm2.bias\n",
493
+ "encoder.down_blocks.1.resnets.1.norm2.weight\n",
494
+ "encoder.down_blocks.2.downsamplers.0.conv.bias\n",
495
+ "encoder.down_blocks.2.downsamplers.0.conv.weight\n",
496
+ "encoder.down_blocks.2.resnets.0.conv1.bias\n",
497
+ "encoder.down_blocks.2.resnets.0.conv1.weight\n",
498
+ "encoder.down_blocks.2.resnets.0.conv2.bias\n",
499
+ "encoder.down_blocks.2.resnets.0.conv2.weight\n",
500
+ "encoder.down_blocks.2.resnets.0.conv_shortcut.bias\n",
501
+ "encoder.down_blocks.2.resnets.0.conv_shortcut.weight\n",
502
+ "encoder.down_blocks.2.resnets.0.norm1.bias\n",
503
+ "encoder.down_blocks.2.resnets.0.norm1.weight\n",
504
+ "encoder.down_blocks.2.resnets.0.norm2.bias\n",
505
+ "encoder.down_blocks.2.resnets.0.norm2.weight\n",
506
+ "encoder.down_blocks.2.resnets.1.conv1.bias\n",
507
+ "encoder.down_blocks.2.resnets.1.conv1.weight\n",
508
+ "encoder.down_blocks.2.resnets.1.conv2.bias\n",
509
+ "encoder.down_blocks.2.resnets.1.conv2.weight\n",
510
+ "encoder.down_blocks.2.resnets.1.norm1.bias\n",
511
+ "encoder.down_blocks.2.resnets.1.norm1.weight\n",
512
+ "encoder.down_blocks.2.resnets.1.norm2.bias\n",
513
+ "encoder.down_blocks.2.resnets.1.norm2.weight\n",
514
+ "encoder.down_blocks.3.downsamplers.0.conv.bias\n",
515
+ "encoder.down_blocks.3.downsamplers.0.conv.weight\n",
516
+ "encoder.down_blocks.3.resnets.0.conv1.weight\n",
517
+ "encoder.down_blocks.3.resnets.0.conv_shortcut.bias\n",
518
+ "encoder.down_blocks.3.resnets.0.conv_shortcut.weight\n",
519
+ "encoder.down_blocks.3.resnets.0.norm1.bias\n",
520
+ "encoder.down_blocks.3.resnets.0.norm1.weight\n"
521
+ ]
522
+ }
523
+ ],
524
+ "source": [
525
+ "import torch\n",
526
+ "from diffusers import AutoencoderKL,AsymmetricAutoencoderKL\n",
527
+ "from tqdm import tqdm\n",
528
+ "import torch.nn.init as init\n",
529
+ "\n",
530
+ "def log(message):\n",
531
+ " print(message)\n",
532
+ "\n",
533
+ "def main():\n",
534
+ " checkpoint_path_old = \"asymmetric_vae_new\"\n",
535
+ " checkpoint_path_new = \"vae16x32ch_empty\"\n",
536
+ " device = \"cuda\"\n",
537
+ " dtype = torch.float32\n",
538
+ "\n",
539
+ " # Загрузка моделей\n",
540
+ " old_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
541
+ " new_unet = AutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
542
+ "\n",
543
+ " old_state_dict = old_unet.state_dict()\n",
544
+ " new_state_dict = new_unet.state_dict()\n",
545
+ "\n",
546
+ " transferred_state_dict = {}\n",
547
+ " transfer_stats = {\n",
548
+ " \"перенесено\": 0,\n",
549
+ " \"несовпадение_размеров\": 0,\n",
550
+ " \"пропущено\": 0\n",
551
+ " }\n",
552
+ "\n",
553
+ " transferred_keys = set()\n",
554
+ "\n",
555
+ " # Обрабатываем каждый ключ старой модели\n",
556
+ " for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n",
557
+ " new_key = old_key\n",
558
+ "\n",
559
+ " if new_key in new_state_dict:\n",
560
+ " if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
561
+ " transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
562
+ " transferred_keys.add(new_key)\n",
563
+ " transfer_stats[\"перенесено\"] += 1\n",
564
+ " else:\n",
565
+ " log(f\"✗ Несовпадение размеров: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})\")\n",
566
+ " transfer_stats[\"несовпадение_размеров\"] += 1\n",
567
+ " else:\n",
568
+ " log(f\"? Ключ не найден в новой модели: {old_key} -> {old_state_dict[old_key].shape}\")\n",
569
+ " transfer_stats[\"пропущено\"] += 1\n",
570
+ "\n",
571
+ " # Обновляем состояние новой модели перенесенными весами\n",
572
+ " new_state_dict.update(transferred_state_dict)\n",
573
+ " \n",
574
+ " # Инициализируем веса для нового mid блока\n",
575
+ " #new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)\n",
576
+ " \n",
577
+ " new_unet.load_state_dict(new_state_dict)\n",
578
+ " new_unet.save_pretrained(\"vae16x32ch\")\n",
579
+ "\n",
580
+ " # Получаем список неперенесенных ключей\n",
581
+ " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
582
+ "\n",
583
+ " print(\"Статистика переноса:\", transfer_stats)\n",
584
+ " print(\"Неперенесенные ключи в новой модели:\")\n",
585
+ " for key in non_transferred_keys:\n",
586
+ " print(key)\n",
587
+ "\n",
588
+ "if __name__ == \"__main__\":\n",
589
+ " main()"
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": 1,
595
+ "id": "b316ee6c-d295-4396-9177-78e39a53055b",
596
+ "metadata": {},
597
+ "outputs": [
598
+ {
599
+ "name": "stderr",
600
+ "output_type": "stream",
601
+ "text": [
602
+ "The config attributes {'block_out_channels': [128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
603
+ ]
604
+ },
605
+ {
606
+ "name": "stdout",
607
+ "output_type": "stream",
608
+ "text": [
609
+ "ok\n"
610
+ ]
611
+ }
612
+ ],
613
+ "source": [
614
+ "import torch\n",
615
+ "\n",
616
+ "from torchvision import transforms, utils\n",
617
+ "\n",
618
+ "import diffusers\n",
619
+ "from diffusers import AsymmetricAutoencoderKL\n",
620
+ "\n",
621
+ "from diffusers.utils import load_image\n",
622
+ "\n",
623
+ "def crop_image_to_nearest_divisible_by_8(img):\n",
624
+ " # Check if the image height and width are divisible by 8\n",
625
+ " if img.shape[1] % 8 == 0 and img.shape[2] % 8 == 0:\n",
626
+ " return img\n",
627
+ " else:\n",
628
+ " # Calculate the closest lower resolution divisible by 8\n",
629
+ " new_height = img.shape[1] - (img.shape[1] % 8)\n",
630
+ " new_width = img.shape[2] - (img.shape[2] % 8)\n",
631
+ " \n",
632
+ " # Use CenterCrop to crop the image\n",
633
+ " transform = transforms.CenterCrop((new_height, new_width), interpolation=transforms.InterpolationMode.BILINEAR)\n",
634
+ " img = transform(img).to(torch.float32).clamp(-1, 1)\n",
635
+ " \n",
636
+ " return img\n",
637
+ " \n",
638
+ "to_tensor = transforms.ToTensor()\n",
639
+ "\n",
640
+ "device = \"cuda\"\n",
641
+ "dtype=torch.float16\n",
642
+ "vae = AsymmetricAutoencoderKL.from_pretrained(\"asymmetric_vae\",torch_dtype=dtype).to(device).eval()\n",
643
+ "\n",
644
+ "image = load_image(\"123456789.jpg\")\n",
645
+ "\n",
646
+ "image = crop_image_to_nearest_divisible_by_8(to_tensor(image)).unsqueeze(0).to(device,dtype=dtype)\n",
647
+ "\n",
648
+ "upscaled_image = vae(image).sample\n",
649
+ "#vae.config.scaled_factor\n",
650
+ "# Save the reconstructed image\n",
651
+ "utils.save_image(upscaled_image, \"test.png\")\n",
652
+ "print('ok')"
653
+ ]
654
+ },
655
+ {
656
+ "cell_type": "code",
657
+ "execution_count": 11,
658
+ "id": "5a01b8e9-73c9-4da7-a097-e334019bd8e9",
659
+ "metadata": {},
660
+ "outputs": [
661
+ {
662
+ "name": "stderr",
663
+ "output_type": "stream",
664
+ "text": [
665
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False, 'latents_mean': [-0.03542253375053406, 0.20086465775966644, -0.016413161531090736, -0.0956302210688591, -0.2672063112258911, 0.2609933018684387, -0.07806991040706635, -0.48407721519470215, 0.21844269335269928, -0.1122383326292038, 0.27197545766830444, -0.18958772718906403, 0.18776826560497284, 0.0987580344080925, 0.2837068736553192, -0.4486690163612366, 0.4816776514053345, 0.02947971224784851, -0.1337375044822693, -0.39750921726226807, -0.08513020724058151, -0.054023586213588715, -0.3943594992160797, 0.23918119072914124, -0.12466679513454437, 0.09935147315263748, 0.31858691573143005, 0.48585832118988037, -0.6416525840759277, -0.15164820849895477, -0.4693508744239807, -0.13071806728839874], 'latents_std': [1.5792087316513062, 1.5769503116607666, 1.5864241123199463, 1.6454921960830688, 1.5336694717407227, 1.5587652921676636, 1.5838669538497925, 1.5659377574920654, 1.6860467195510864, 1.5192310810089111, 1.573639988899231, 1.5953549146652222, 1.5271092653274536, 1.6246271133422852, 1.7054023742675781, 1.607722282409668, 1.558642864227295, 1.5824549198150635, 1.6202995777130127, 1.6206320524215698, 1.6379750967025757, 1.6527063846588135, 1.498811960220337, 1.5706247091293335, 1.5854856967926025, 1.4828169345855713, 1.5693111419677734, 1.692481517791748, 1.6409776210784912, 1.6216280460357666, 1.6087706089019775, 1.5776633024215698]} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
666
+ "Перенос весов: 100%|██████████| 284/284 [00:00<00:00, 30094.80it/s]\n"
667
+ ]
668
+ },
669
+ {
670
+ "name": "stdout",
671
+ "output_type": "stream",
672
+ "text": [
673
+ "Статистика: {'перенесено': 292, 'несовпадение_размеров': 0, 'пропущено': 10}\n",
674
+ "\n",
675
+ "Неперенесенные ключи:\n"
676
+ ]
677
+ }
678
+ ],
679
+ "source": [
680
+ "import torch\n",
681
+ "from diffusers import AutoencoderKL, AsymmetricAutoencoderKL\n",
682
+ "from tqdm import tqdm\n",
683
+ "\n",
684
+ "\n",
685
+ "def log(message):\n",
686
+ " print(message)\n",
687
+ "\n",
688
+ "\n",
689
+ "def remap_key(old_key: str):\n",
690
+ " \"\"\"\n",
691
+ " Смещение только encoder.down_blocks\n",
692
+ " \"\"\"\n",
693
+ "\n",
694
+ " if \"encoder.down_blocks\" not in old_key:\n",
695
+ " return [old_key]\n",
696
+ "\n",
697
+ " parts = old_key.split(\".\")\n",
698
+ " block_id = int(parts[2])\n",
699
+ "\n",
700
+ " if block_id == 0:\n",
701
+ " # первый блок копируем дважды\n",
702
+ " return [\n",
703
+ " old_key.replace(\"down_blocks.0\", \"down_blocks.0\"),\n",
704
+ " old_key.replace(\"down_blocks.0\", \"down_blocks.1\"),\n",
705
+ " ]\n",
706
+ "\n",
707
+ " # остальные блоки сдвигаем\n",
708
+ " new_block = block_id + 1\n",
709
+ " return [old_key.replace(f\"down_blocks.{block_id}\", f\"down_blocks.{new_block}\")]\n",
710
+ "\n",
711
+ "\n",
712
+ "def main():\n",
713
+ " checkpoint_path_old = \"asymmetric_vae_new\"\n",
714
+ " checkpoint_path_new = \"vae16x32ch_empty\"\n",
715
+ "\n",
716
+ " device = \"cuda\"\n",
717
+ " dtype = torch.float32\n",
718
+ "\n",
719
+ " old_vae = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_old).to(device, dtype=dtype)\n",
720
+ " new_vae = AutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)\n",
721
+ "\n",
722
+ " old_state_dict = old_vae.state_dict()\n",
723
+ " new_state_dict = new_vae.state_dict()\n",
724
+ "\n",
725
+ " transferred_state_dict = {}\n",
726
+ " transferred_keys = set()\n",
727
+ "\n",
728
+ " transfer_stats = {\n",
729
+ " \"перенесено\": 0,\n",
730
+ " \"несовпадение_размеров\": 0,\n",
731
+ " \"пропущено\": 0\n",
732
+ " }\n",
733
+ "\n",
734
+ " for old_key in tqdm(old_state_dict.keys(), desc=\"Перенос весов\"):\n",
735
+ "\n",
736
+ " new_keys = remap_key(old_key)\n",
737
+ "\n",
738
+ " for new_key in new_keys:\n",
739
+ "\n",
740
+ " if new_key in new_state_dict:\n",
741
+ "\n",
742
+ " if old_state_dict[old_key].shape == new_state_dict[new_key].shape:\n",
743
+ " transferred_state_dict[new_key] = old_state_dict[old_key].clone()\n",
744
+ " transferred_keys.add(new_key)\n",
745
+ " transfer_stats[\"перенесено\"] += 1\n",
746
+ " else:\n",
747
+ " log(\n",
748
+ " f\"✗ Несовпадение размеров: \"\n",
749
+ " f\"{old_key} {old_state_dict[old_key].shape} \"\n",
750
+ " f\"-> {new_key} {new_state_dict[new_key].shape}\"\n",
751
+ " )\n",
752
+ " transfer_stats[\"несовпадение_р��змеров\"] += 1\n",
753
+ " else:\n",
754
+ " transfer_stats[\"пропущено\"] += 1\n",
755
+ "\n",
756
+ " new_state_dict.update(transferred_state_dict)\n",
757
+ "\n",
758
+ " new_vae.load_state_dict(new_state_dict)\n",
759
+ " new_vae.save_pretrained(\"vae16x32ch\")\n",
760
+ "\n",
761
+ " non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)\n",
762
+ "\n",
763
+ " print(\"Статистика:\", transfer_stats)\n",
764
+ "\n",
765
+ " print(\"\\nНеперенесенные ключи:\")\n",
766
+ " for key in non_transferred_keys:\n",
767
+ " print(key)\n",
768
+ "\n",
769
+ "\n",
770
+ "if __name__ == \"__main__\":\n",
771
+ " main()"
772
+ ]
773
+ },
774
+ {
775
+ "cell_type": "code",
776
+ "execution_count": null,
777
+ "id": "fe8f1ceb-8d3e-4df5-a1dc-1b56a0d398a2",
778
+ "metadata": {},
779
+ "outputs": [],
780
+ "source": []
781
+ }
782
+ ],
783
+ "metadata": {
784
+ "kernelspec": {
785
+ "display_name": "Python3 (ipykernel)",
786
+ "language": "python",
787
+ "name": "python3"
788
+ },
789
+ "language_info": {
790
+ "codemirror_mode": {
791
+ "name": "ipython",
792
+ "version": 3
793
+ },
794
+ "file_extension": ".py",
795
+ "mimetype": "text/x-python",
796
+ "name": "python",
797
+ "nbconvert_exporter": "python",
798
+ "pygments_lexer": "ipython3",
799
+ "version": "3.12.12"
800
+ }
801
+ },
802
+ "nbformat": 4,
803
+ "nbformat_minor": 5
804
+ }
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bdc3ae2397d59d1b7541aa09b3cd0727f5ffc2dd587f4485ef480c9113a275d
3
+ size 343311604
train_vae_16x.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import gc
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import torchvision.transforms as transforms
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
17
+ # QWEN: импорт класса
18
+ from diffusers import AutoencoderKLQwenImage
19
+ from diffusers import AutoencoderKLWan
20
+
21
+ from accelerate import Accelerator
22
+ from PIL import Image, UnidentifiedImageError
23
+ from tqdm import tqdm
24
+ import bitsandbytes as bnb
25
+ import wandb
26
+ import lpips # pip install lpips
27
+ from FDL_pytorch import FDL_loss # pip install fdl-pytorch
28
+ from collections import deque
29
+
30
+ # --------------------------- Параметры ---------------------------
31
+ ds_path = "/workspace/d23"
32
+ project = "vae16x32ch"
33
+ batch_size = 1
34
+ base_learning_rate = 6e-6
35
+ min_learning_rate = 7e-7
36
+ num_epochs = 1
37
+ sample_interval_share = 30
38
+ use_wandb = True
39
+ save_model = True
40
+ use_decay = True
41
+ optimizer_type = "adam8bit"
42
+ dtype = torch.float32
43
+
44
+ model_resolution = 768 #448 #288
45
+ high_resolution = 768 #896 #576
46
+ limit = 0
47
+ save_barrier = 1.3
48
+ warmup_percent = 0.005
49
+ percentile_clipping = 99
50
+ beta2 = 0.997
51
+ eps = 1e-8
52
+ clip_grad_norm = 1.0
53
+ mixed_precision = "no"
54
+ gradient_accumulation_steps = 1
55
+ generated_folder = "samples"
56
+ save_as = "vae16x32ch_new"
57
+ num_workers = 0
58
+ device = None
59
+ torch.backends.cuda.matmul.allow_tf32 = True
60
+ torch.backends.cudnn.allow_tf32 = True
61
+ # Включение Flash Attention 2/SDPA #MAX_JOBS=4 pip install flash-attn --no-build-isolation
62
+ torch.backends.cuda.enable_flash_sdp(True)
63
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
64
+ torch.backends.cuda.enable_math_sdp(False)
65
+
66
+ # --- Режимы обучения ---
67
+ # QWEN: учим только декодер
68
+ train_decoder_only = False
69
+ train_up_only = False
70
+ full_training = True # если True — учим весь VAE и добавляем KL (ниже)
71
+ kl_ratio = 0.00
72
+
73
+ # Доли лоссов
74
+ loss_ratios = {
75
+ "lpips": 0.70,#0.50,
76
+ "fdl" : 0.10,#0.25,
77
+ "edge": 0.05,
78
+ "mse": 0.10,
79
+ "mae": 0.05,
80
+ "kl": 0.00, # активируем при full_training=True
81
+ }
82
+ median_coeff_steps = 250
83
+
84
+ resize_long_side = 1280 # ресайз длинной стороны исходных картинок
85
+
86
+ # QWEN: конфиг загрузки модели
87
+ vae_kind = "kl" # "qwen" или "kl" (обычный)
88
+
89
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
90
+
91
+ accelerator = Accelerator(
92
+ mixed_precision=mixed_precision,
93
+ gradient_accumulation_steps=gradient_accumulation_steps
94
+ )
95
+ device = accelerator.device
96
+
97
+ # reproducibility
98
+ seed = int(datetime.now().strftime("%Y%m%d")) + 13
99
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
100
+ torch.backends.cudnn.benchmark = False
101
+
102
+ # --------------------------- WandB ---------------------------
103
+ if use_wandb and accelerator.is_main_process:
104
+ wandb.init(project=project, config={
105
+ "batch_size": batch_size,
106
+ "base_learning_rate": base_learning_rate,
107
+ "num_epochs": num_epochs,
108
+ "optimizer_type": optimizer_type,
109
+ "model_resolution": model_resolution,
110
+ "high_resolution": high_resolution,
111
+ "gradient_accumulation_steps": gradient_accumulation_steps,
112
+ "train_decoder_only": train_decoder_only,
113
+ "full_training": full_training,
114
+ "kl_ratio": kl_ratio,
115
+ "vae_kind": vae_kind,
116
+ })
117
+
118
+ # --------------------------- VAE ---------------------------
119
+ def get_core_model(model):
120
+ m = model
121
+ # если модель уже обёрнута torch.compile
122
+ if hasattr(m, "_orig_mod"):
123
+ m = m._orig_mod
124
+ return m
125
+
126
+ def is_video_vae(model) -> bool:
127
+ # WAN/Qwen — это видео-VAEs
128
+ if vae_kind in ("wan", "qwen"):
129
+ return True
130
+ # fallback по структуре (если понадобится)
131
+ try:
132
+ core = get_core_model(model)
133
+ enc = getattr(core, "encoder", None)
134
+ conv_in = getattr(enc, "conv_in", None)
135
+ w = getattr(conv_in, "weight", None)
136
+ if isinstance(w, torch.nn.Parameter):
137
+ return w.ndim == 5
138
+ except Exception:
139
+ pass
140
+ return False
141
+
142
+ # загрузка
143
+ if vae_kind == "qwen":
144
+ vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
145
+ else:
146
+ if vae_kind == "wan":
147
+ vae = AutoencoderKLWan.from_pretrained(project)
148
+ else:
149
+ # старое поведение (пример)
150
+ if model_resolution==high_resolution:
151
+ vae = AutoencoderKL.from_pretrained(project)
152
+ else:
153
+ vae = AsymmetricAutoencoderKL.from_pretrained(project)
154
+
155
+ vae = vae.to(dtype)
156
+
157
+ # torch.compile (опционально)
158
+ if hasattr(torch, "compile"):
159
+ try:
160
+ vae = torch.compile(vae)
161
+ except Exception as e:
162
+ print(f"[WARN] torch.compile failed: {e}")
163
+
164
+ # --------------------------- Freeze/Unfreeze ---------------------------
165
+ core = get_core_model(vae)
166
+
167
+ for p in core.parameters():
168
+ p.requires_grad = False
169
+
170
+ unfrozen_param_names = []
171
+
172
+ if full_training and not train_decoder_only:
173
+ for name, p in core.named_parameters():
174
+ p.requires_grad = True
175
+ unfrozen_param_names.append(name)
176
+ loss_ratios["kl"] = float(kl_ratio)
177
+ trainable_module = core
178
+ else:
179
+ # учим только 0-й блок декодера + post_quant_conv
180
+ if hasattr(core, "decoder"):
181
+ if train_up_only:#hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
182
+ # --- только 0-й up_block ---
183
+ for name, p in core.decoder.up_blocks[0].named_parameters():
184
+ p.requires_grad = True
185
+ unfrozen_param_names.append(f"{name}")
186
+ else:
187
+ print("Decoder — fallback to full decoder")
188
+ for name, p in core.decoder.named_parameters():
189
+ p.requires_grad = True
190
+ unfrozen_param_names.append(f"decoder.{name}")
191
+ if hasattr(core, "post_quant_conv"):
192
+ for name, p in core.post_quant_conv.named_parameters():
193
+ p.requires_grad = True
194
+ unfrozen_param_names.append(f"post_quant_conv.{name}")
195
+ trainable_module = core.decoder if hasattr(core, "decoder") else core
196
+
197
+
198
+ print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
199
+ for nm in unfrozen_param_names[:200]:
200
+ print(" ", nm)
201
+
202
+ # --------------------------- Датасет ---------------------------
203
+ class PngFolderDataset(Dataset):
204
+ def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
205
+ self.root_dir = root_dir
206
+ self.resolution = resolution
207
+ self.paths = []
208
+ for root, _, files in os.walk(root_dir):
209
+ for fname in files:
210
+ if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
211
+ self.paths.append(os.path.join(root, fname))
212
+ if limit:
213
+ self.paths = self.paths[:limit]
214
+ valid = []
215
+ for p in self.paths:
216
+ try:
217
+ with Image.open(p) as im:
218
+ im.verify()
219
+ valid.append(p)
220
+ except (OSError, UnidentifiedImageError):
221
+ continue
222
+ self.paths = valid
223
+ if len(self.paths) == 0:
224
+ raise RuntimeError(f"No valid PNG images found under {root_dir}")
225
+ random.shuffle(self.paths)
226
+
227
+ def __len__(self):
228
+ return len(self.paths)
229
+
230
+ def __getitem__(self, idx):
231
+ p = self.paths[idx % len(self.paths)]
232
+ with Image.open(p) as img:
233
+ img = img.convert("RGB")
234
+ if not resize_long_side or resize_long_side <= 0:
235
+ return img
236
+ w, h = img.size
237
+ long = max(w, h)
238
+ if long <= resize_long_side:
239
+ return img
240
+ scale = resize_long_side / float(long)
241
+ new_w = int(round(w * scale))
242
+ new_h = int(round(h * scale))
243
+ return img.resize((new_w, new_h), Image.BICUBIC)
244
+
245
+ def random_crop(img, sz):
246
+ w, h = img.size
247
+ if w < sz or h < sz:
248
+ img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
249
+ x = random.randint(0, max(1, img.width - sz))
250
+ y = random.randint(0, max(1, img.height - sz))
251
+ return img.crop((x, y, x + sz, y + sz))
252
+
253
+ tfm = transforms.Compose([
254
+ transforms.ToTensor(),
255
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
256
+ ])
257
+
258
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
259
+ print("len(dataset)",len(dataset))
260
+ if len(dataset) < batch_size:
261
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
262
+
263
+ def collate_fn(batch):
264
+ imgs = []
265
+ for img in batch:
266
+ img = random_crop(img, high_resolution)
267
+ imgs.append(tfm(img))
268
+ return torch.stack(imgs)
269
+
270
+ dataloader = DataLoader(
271
+ dataset,
272
+ batch_size=batch_size,
273
+ shuffle=True,
274
+ collate_fn=collate_fn,
275
+ num_workers=num_workers,
276
+ pin_memory=True,
277
+ drop_last=True
278
+ )
279
+
280
+ # --------------------------- Оптимизатор ---------------------------
281
+ def get_param_groups(module, weight_decay=0.001):
282
+ no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
283
+ decay_params, no_decay_params = [], []
284
+ for n, p in vae.named_parameters(): # глобально по vae, с фильтром requires_grad
285
+ if not p.requires_grad:
286
+ continue
287
+ if any(nd in n for nd in no_decay):
288
+ no_decay_params.append(p)
289
+ else:
290
+ decay_params.append(p)
291
+ return [
292
+ {"params": decay_params, "weight_decay": weight_decay},
293
+ {"params": no_decay_params, "weight_decay": 0.0},
294
+ ]
295
+
296
+ def get_param_groups(module, weight_decay=0.001):
297
+ no_decay_tokens = ("bias", "norm", "rms", "layernorm")
298
+ decay_params, no_decay_params = [], []
299
+ for n, p in module.named_parameters():
300
+ if not p.requires_grad:
301
+ continue
302
+ n_l = n.lower()
303
+ if any(t in n_l for t in no_decay_tokens):
304
+ no_decay_params.append(p)
305
+ else:
306
+ decay_params.append(p)
307
+ return [
308
+ {"params": decay_params, "weight_decay": weight_decay},
309
+ {"params": no_decay_params, "weight_decay": 0.0},
310
+ ]
311
+
312
+ def create_optimizer(name, param_groups):
313
+ if name == "adam8bit":
314
+ return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
315
+ raise ValueError(name)
316
+
317
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
318
+ optimizer = create_optimizer(optimizer_type, param_groups)
319
+
320
+ # --------------------------- LR schedule ---------------------------
321
+ batches_per_epoch = len(dataloader)
322
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
323
+ total_steps = steps_per_epoch * num_epochs
324
+
325
+ def lr_lambda(step):
326
+ if not use_decay:
327
+ return 1.0
328
+ x = float(step) / float(max(1, total_steps))
329
+ warmup = float(warmup_percent)
330
+ min_ratio = float(min_learning_rate) / float(base_learning_rate)
331
+ if x < warmup:
332
+ return min_ratio + (1.0 - min_ratio) * (x / warmup)
333
+ decay_ratio = (x - warmup) / (1.0 - warmup)
334
+ return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
335
+
336
+ scheduler = LambdaLR(optimizer, lr_lambda)
337
+
338
+ # Подготовка
339
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
340
+ trainable_params = [p for p in vae.parameters() if p.requires_grad]
341
+
342
+ # fdl
343
+ fdl_loss = FDL_loss()
344
+ fdl_loss = fdl_loss.to(accelerator.device)
345
+
346
+ # --------------------------- LPIPS и вспомогательные ---------------------------
347
+ _lpips_net = None
348
+ def _get_lpips():
349
+ global _lpips_net
350
+ if _lpips_net is None:
351
+ _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
352
+ return _lpips_net
353
+
354
+ _sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
355
+ _sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
356
+ def sobel_edges(x: torch.Tensor) -> torch.Tensor:
357
+ C = x.shape[1]
358
+ kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
359
+ ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
360
+ gx = F.conv2d(x, kx, padding=1, groups=C)
361
+ gy = F.conv2d(x, ky, padding=1, groups=C)
362
+ return torch.sqrt(gx * gx + gy * gy + 1e-12)
363
+
364
+ class MedianLossNormalizer:
365
+ def __init__(self, desired_ratios: dict, window_steps: int):
366
+ s = sum(desired_ratios.values())
367
+ self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
368
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
369
+ self.window = window_steps
370
+
371
+ def update_and_total(self, abs_losses: dict):
372
+ for k, v in abs_losses.items():
373
+ if k in self.buffers:
374
+ self.buffers[k].append(float(v.detach().abs().cpu()))
375
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
376
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
377
+ total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
378
+ return total, coeffs, meds
379
+
380
+ if full_training and not train_decoder_only:
381
+ loss_ratios["kl"] = float(kl_ratio)
382
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
383
+
384
+ # --------------------------- Сэмплы ---------------------------
385
+ @torch.no_grad()
386
+ def get_fixed_samples(n=3):
387
+ idx = random.sample(range(len(dataset)), min(n, len(dataset)))
388
+ pil_imgs = [dataset[i] for i in idx]
389
+ tensors = []
390
+ for img in pil_imgs:
391
+ img = random_crop(img, high_resolution)
392
+ tensors.append(tfm(img))
393
+ return torch.stack(tensors).to(accelerator.device, dtype)
394
+
395
+ fixed_samples = get_fixed_samples()
396
+
397
+ @torch.no_grad()
398
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
399
+ arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
400
+ return Image.fromarray(arr)
401
+
402
+
403
+ @torch.no_grad()
404
+ def generate_and_save_samples(step=None):
405
+ try:
406
+ #temp_vae = accelerator.unwrap_model(vae).eval()
407
+ if hasattr(vae, "module"):
408
+ # Если это DDP или DistributedDataParallel
409
+ unwrapped_vae = vae.module
410
+ else:
411
+ unwrapped_vae = vae
412
+
413
+ # Если использовался torch.compile, достаем оригинал
414
+ if hasattr(unwrapped_vae, "_orig_mod"):
415
+ temp_vae = unwrapped_vae._orig_mod
416
+ else:
417
+ temp_vae = unwrapped_vae
418
+
419
+ temp_vae = temp_vae.eval()
420
+ lpips_net = _get_lpips()
421
+ with torch.no_grad():
422
+ orig_high = fixed_samples
423
+ orig_low = F.interpolate(
424
+ orig_high,
425
+ size=(model_resolution, model_resolution),
426
+ mode="bilinear",
427
+ align_corners=False
428
+ )
429
+ model_dtype = next(temp_vae.parameters()).dtype
430
+ orig_low = orig_low.to(dtype=model_dtype)
431
+
432
+ # Encode/decode с учётом видео-режима
433
+ if is_video_vae(temp_vae):
434
+ x_in = orig_low.unsqueeze(2) # [B,3,1,H,W]
435
+ enc = temp_vae.encode(x_in)
436
+ latents_mean = enc.latent_dist.mean
437
+ dec = temp_vae.decode(latents_mean).sample # [B,3,1,H,W]
438
+ rec = dec.squeeze(2) # [B,3,H,W]
439
+ else:
440
+ enc = temp_vae.encode(orig_low)
441
+ latents_mean = enc.latent_dist.mean
442
+ rec = temp_vae.decode(latents_mean).sample
443
+
444
+ # Подгон размеров, если надо
445
+ #if rec.shape[-2:] != orig_high.shape[-2:]:
446
+ # rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
447
+
448
+ # Сохраняем все real/decoded
449
+ for i in range(rec.shape[0]):
450
+ real_img = _to_pil_uint8(orig_high[i])
451
+ dec_img = _to_pil_uint8(rec[i])
452
+ real_img.save(f"{generated_folder}/sample_real_{i}.png")
453
+ dec_img.save(f"{generated_folder}/sample_decoded_{i}.png")
454
+
455
+ # LPIPS
456
+ lpips_scores = []
457
+ for i in range(rec.shape[0]):
458
+ orig_full = orig_high[i:i+1].to(torch.float32)
459
+ rec_full = rec[i:i+1].to(torch.float32)
460
+ #if rec_full.shape[-2:] != orig_full.shape[-2:]:
461
+ # rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
462
+ lpips_val = lpips_net(orig_full, rec_full).item()
463
+ lpips_scores.append(lpips_val)
464
+ avg_lpips = float(np.mean(lpips_scores))
465
+
466
+ # W&B логирование
467
+ if use_wandb and accelerator.is_main_process:
468
+ log_data = {"lpips_mean": avg_lpips}
469
+ for i in range(rec.shape[0]):
470
+ log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.png", caption=f"real_{i}")
471
+ log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.png", caption=f"decoded_{i}")
472
+ wandb.log(log_data, step=step)
473
+
474
+ finally:
475
+ gc.collect()
476
+ torch.cuda.empty_cache()
477
+
478
+
479
+ if accelerator.is_main_process and save_model:
480
+ print("Генерация сэмплов до старта обучения...")
481
+ generate_and_save_samples(0)
482
+
483
+ accelerator.wait_for_everyone()
484
+
485
+ # --------------------------- Тренировка ---------------------------
486
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
487
+ global_step = 0
488
+ min_loss = float("inf")
489
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
490
+
491
+ for epoch in range(num_epochs):
492
+ vae.train()
493
+ batch_losses, batch_grads = [], []
494
+ track_losses = {k: [] for k in loss_ratios.keys()}
495
+
496
+ for imgs in dataloader:
497
+ with accelerator.accumulate(vae):
498
+ imgs = imgs.to(accelerator.device)
499
+
500
+ if high_resolution != model_resolution:
501
+ imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution),mode="area") # mode="bilinear", align_corners=False)
502
+ else:
503
+ imgs_low = imgs
504
+
505
+ model_dtype = next(vae.parameters()).dtype
506
+ imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
507
+
508
+ # Вместо: current_vae = accelerator.unwrap_model(vae)
509
+ unwrapped = vae.module if hasattr(vae, "module") else vae
510
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
511
+
512
+
513
+ # QWEN: encode/decode с T=1
514
+ if is_video_vae(current_vae):
515
+ x_in = imgs_low_model.unsqueeze(2) # [B,3,1,H,W]
516
+ enc = current_vae.encode(x_in)
517
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
518
+ dec = current_vae.decode(latents).sample # [B,3,1,H,W]
519
+ rec = dec.squeeze(2) # [B,3,H,W]
520
+ else:
521
+ enc = current_vae.encode(imgs_low_model)
522
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
523
+ rec = current_vae.decode(latents).sample
524
+
525
+ #if rec.shape[-2:] != imgs.shape[-2:]:
526
+ # rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
527
+
528
+ rec_f32 = rec.to(torch.float32)
529
+ imgs_f32 = imgs.to(torch.float32)
530
+
531
+ abs_losses = {
532
+ "mae": F.l1_loss(rec_f32, imgs_f32),
533
+ "mse": F.mse_loss(rec_f32, imgs_f32),
534
+ "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
535
+ "fdl": fdl_loss(rec_f32, imgs_f32),
536
+ "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
537
+ }
538
+
539
+ if full_training and not train_decoder_only:
540
+ mean = enc.latent_dist.mean
541
+ logvar = enc.latent_dist.logvar
542
+ kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
543
+ abs_losses["kl"] = kl
544
+ else:
545
+ abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
546
+
547
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
548
+
549
+ if torch.isnan(total_loss) or torch.isinf(total_loss):
550
+ raise RuntimeError("NaN/Inf loss")
551
+
552
+ accelerator.backward(total_loss)
553
+
554
+ grad_norm = torch.tensor(0.0, device=accelerator.device)
555
+ if accelerator.sync_gradients:
556
+ grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
557
+ optimizer.step()
558
+ scheduler.step()
559
+ optimizer.zero_grad(set_to_none=True)
560
+ global_step += 1
561
+ progress.update(1)
562
+
563
+ if accelerator.is_main_process:
564
+ try:
565
+ current_lr = optimizer.param_groups[0]["lr"]
566
+ except Exception:
567
+ current_lr = scheduler.get_last_lr()[0]
568
+
569
+ batch_losses.append(total_loss.detach().item())
570
+ batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
571
+ for k, v in abs_losses.items():
572
+ track_losses[k].append(float(v.detach().item()))
573
+
574
+ if use_wandb and accelerator.sync_gradients:
575
+ log_dict = {
576
+ "total_loss": float(total_loss.detach().item()),
577
+ "learning_rate": current_lr,
578
+ "epoch": epoch,
579
+ "grad_norm": batch_grads[-1],
580
+ }
581
+ for k, v in abs_losses.items():
582
+ log_dict[f"loss_{k}"] = float(v.detach().item())
583
+ for k in coeffs:
584
+ log_dict[f"coeff_{k}"] = float(coeffs[k])
585
+ log_dict[f"median_{k}"] = float(meds[k])
586
+ wandb.log(log_dict, step=global_step)
587
+
588
+ if global_step > 0 and global_step % sample_interval == 0:
589
+ if accelerator.is_main_process:
590
+ generate_and_save_samples(global_step)
591
+ accelerator.wait_for_everyone()
592
+
593
+ n_micro = sample_interval * gradient_accumulation_steps
594
+ avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
595
+ avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
596
+
597
+ if accelerator.is_main_process:
598
+ print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
599
+ if save_model and avg_loss < min_loss * save_barrier:
600
+ min_loss = avg_loss
601
+ unwrapped = vae.module if hasattr(vae, "module") else vae
602
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
603
+ current_vae.save_pretrained(save_as)
604
+ if use_wandb:
605
+ wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
606
+
607
+ if accelerator.is_main_process:
608
+ epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
609
+ print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
610
+ if use_wandb:
611
+ wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
612
+
613
+ # --------------------------- Финальное сохранение ---------------------------
614
+ if accelerator.is_main_process:
615
+ print("Training finished – saving final model")
616
+ if save_model:
617
+ unwrapped = vae.module if hasattr(vae, "module") else vae
618
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
619
+ current_vae.save_pretrained(save_as)
620
+
621
+ accelerator.free_memory()
622
+ if torch.distributed.is_initialized():
623
+ torch.distributed.destroy_process_group()
624
+ print("Готово!")