babkasotona commited on
Commit
b55577e
·
verified ·
1 Parent(s): 000074f

Upload folder using huggingface_hub

Browse files
.ipynb_checkpoints/config-checkpoint.json ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.36.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 32,
20
+ "latents_mean": [
21
+ -0.03542253375053406,
22
+ 0.20086465775966644,
23
+ -0.016413161531090736,
24
+ -0.0956302210688591,
25
+ -0.2672063112258911,
26
+ 0.2609933018684387,
27
+ -0.07806991040706635,
28
+ -0.48407721519470215,
29
+ 0.21844269335269928,
30
+ -0.1122383326292038,
31
+ 0.27197545766830444,
32
+ -0.18958772718906403,
33
+ 0.18776826560497284,
34
+ 0.0987580344080925,
35
+ 0.2837068736553192,
36
+ -0.4486690163612366,
37
+ 0.4816776514053345,
38
+ 0.02947971224784851,
39
+ -0.1337375044822693,
40
+ -0.39750921726226807,
41
+ -0.08513020724058151,
42
+ -0.054023586213588715,
43
+ -0.3943594992160797,
44
+ 0.23918119072914124,
45
+ -0.12466679513454437,
46
+ 0.09935147315263748,
47
+ 0.31858691573143005,
48
+ 0.48585832118988037,
49
+ -0.6416525840759277,
50
+ -0.15164820849895477,
51
+ -0.4693508744239807,
52
+ -0.13071806728839874
53
+ ],
54
+ "latents_std": [
55
+ 1.5792087316513062,
56
+ 1.5769503116607666,
57
+ 1.5864241123199463,
58
+ 1.6454921960830688,
59
+ 1.5336694717407227,
60
+ 1.5587652921676636,
61
+ 1.5838669538497925,
62
+ 1.5659377574920654,
63
+ 1.6860467195510864,
64
+ 1.5192310810089111,
65
+ 1.573639988899231,
66
+ 1.5953549146652222,
67
+ 1.5271092653274536,
68
+ 1.6246271133422852,
69
+ 1.7054023742675781,
70
+ 1.607722282409668,
71
+ 1.558642864227295,
72
+ 1.5824549198150635,
73
+ 1.6202995777130127,
74
+ 1.6206320524215698,
75
+ 1.6379750967025757,
76
+ 1.6527063846588135,
77
+ 1.498811960220337,
78
+ 1.5706247091293335,
79
+ 1.5854856967926025,
80
+ 1.4828169345855713,
81
+ 1.5693111419677734,
82
+ 1.692481517791748,
83
+ 1.6409776210784912,
84
+ 1.6216280460357666,
85
+ 1.6087706089019775,
86
+ 1.5776633024215698
87
+ ],
88
+ "layers_per_block": 2,
89
+ "mid_block_add_attention": true,
90
+ "norm_num_groups": 32,
91
+ "out_channels": 3,
92
+ "sample_size": 32,
93
+ "scaling_factor": 1.0,
94
+ "shift_factor": 0.0,
95
+ "up_block_types": [
96
+ "UpDecoderBlock2D",
97
+ "UpDecoderBlock2D",
98
+ "UpDecoderBlock2D",
99
+ "UpDecoderBlock2D"
100
+ ],
101
+ "use_post_quant_conv": true,
102
+ "use_quant_conv": true
103
+ }
Untitled.ipynb ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "id": "dccce86b-90a0-47c7-aaad-2ebb16d90756",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Картинка загружена: torch.Size([1, 3, 1280, 1280])\n",
14
+ "\n",
15
+ "=======================================================\n",
16
+ "VAE : FLUX.2\n",
17
+ "repo: AiArtLab/sdxs-1b\n",
18
+ "latent_channels : 32\n",
19
+ "scaling_factor : 1.00000\n",
20
+ "shift_factor : 0.00000\n",
21
+ "latents_mean : нет\n",
22
+ "latents_std : нет\n",
23
+ "\n",
24
+ "[encode] raw latents: torch.Size([1, 32, 160, 160])\n",
25
+ "[flux2] patchify : torch.Size([1, 32, 160, 160]) → torch.Size([1, 128, 80, 80])\n",
26
+ "[flux2] BN norm : mean=-0.0096 std=1.7674\n",
27
+ "\n",
28
+ "[STATS] после BN нормализации (128ch):\n",
29
+ " log-variance : -0.0767 (идеал ≈ 0.0)\n",
30
+ " mean : -0.0134\n",
31
+ " std : 0.9624\n",
32
+ " shape : torch.Size([1, 128, 80, 80])\n",
33
+ "\n",
34
+ "[flux2] BN denorm + unpatchify: torch.Size([1, 32, 160, 160])\n",
35
+ "Сохранено: vaetest/decoded_FLUX.2.png\n",
36
+ "\n",
37
+ "=======================================================\n",
38
+ "VAE : vae32ch2\n",
39
+ "repo: vae32ch2\n",
40
+ "latent_channels : 32\n",
41
+ "scaling_factor : 1.00000\n",
42
+ "shift_factor : 0.00000\n",
43
+ "latents_mean : да (32ch)\n",
44
+ "latents_std : да (32ch)\n",
45
+ "\n",
46
+ "[encode] raw latents: torch.Size([1, 32, 160, 160])\n",
47
+ "\n",
48
+ "[STATS] после per-channel нормализации (32ch):\n",
49
+ " log-variance : 0.1192 (идеал ≈ 0.0)\n",
50
+ " mean : -0.0016\n",
51
+ " std : 1.0614\n",
52
+ " shape : torch.Size([1, 32, 160, 160])\n",
53
+ "\n",
54
+ "[vae32ch2] denorm: torch.Size([1, 32, 160, 160])\n",
55
+ "Сохранено: vaetest/decoded_vae32ch2.png\n",
56
+ "\n",
57
+ "=======================================================\n",
58
+ "Готово\n"
59
+ ]
60
+ }
61
+ ],
62
+ "source": [
63
+ "\n",
64
+ "import torch\n",
65
+ "from PIL import Image\n",
66
+ "from diffusers import AutoencoderKL, AutoencoderKLFlux2\n",
67
+ "from torchvision.transforms.functional import to_pil_image\n",
68
+ "import matplotlib.pyplot as plt\n",
69
+ "import os\n",
70
+ "from torchvision.transforms import ToTensor, Normalize, CenterCrop\n",
71
+ "\n",
72
+ "# ── Настройки ─────────────────────────────────────────────────────────────────\n",
73
+ "IMG_PATH = \"1234.png\"\n",
74
+ "OUT_DIR = \"vaetest\"\n",
75
+ "device = \"cuda\"\n",
76
+ "dtype = torch.float32\n",
77
+ "os.makedirs(OUT_DIR, exist_ok=True)\n",
78
+ "\n",
79
+ "VAES = {\n",
80
+ " \"FLUX.2\": (\"flux2\", \"AiArtLab/sdxs-1b\"),\n",
81
+ " \"vae32ch2\": (\"vae32ch\", \"vae32ch2\"),\n",
82
+ "}\n",
83
+ "\n",
84
+ "# ── Patchify / Unpatchify ─────────────────────────────────────────────────────\n",
85
+ "def _patchify_latents(latents):\n",
86
+ " B, C, H, W = latents.shape\n",
87
+ " latents = latents.view(B, C, H // 2, 2, W // 2, 2)\n",
88
+ " latents = latents.permute(0, 1, 3, 5, 2, 4)\n",
89
+ " latents = latents.reshape(B, C * 4, H // 2, W // 2)\n",
90
+ " return latents\n",
91
+ "\n",
92
+ "def _unpatchify_latents(latents):\n",
93
+ " B, C, H, W = latents.shape\n",
94
+ " latents = latents.reshape(B, C // 4, 2, 2, H, W)\n",
95
+ " latents = latents.permute(0, 1, 4, 2, 5, 3)\n",
96
+ " latents = latents.reshape(B, C // 4, H * 2, W * 2)\n",
97
+ " return latents\n",
98
+ "\n",
99
+ "# ── Загрузка картинки ─────────────────────────────────────────────────────────\n",
100
+ "def load_image(path):\n",
101
+ " img = Image.open(path).convert(\"RGB\")\n",
102
+ " w, h = img.size\n",
103
+ " img = CenterCrop((h // 8 * 8, w // 8 * 8))(img)\n",
104
+ " tensor = ToTensor()(img).unsqueeze(0)\n",
105
+ " tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(tensor)\n",
106
+ " return img, tensor.to(device, dtype=dtype)\n",
107
+ "\n",
108
+ "def tensor_to_img(t):\n",
109
+ " t = (t * 0.5 + 0.5).clamp(0, 1)\n",
110
+ " return to_pil_image(t[0])\n",
111
+ "\n",
112
+ "# ── Статистика ────────────────────────────────────────────────────────────────\n",
113
+ "def logvariance(latents):\n",
114
+ " return torch.log(latents.var() + 1e-8).item()\n",
115
+ "\n",
116
+ "def print_stats(name, latents):\n",
117
+ " lv = logvariance(latents)\n",
118
+ " print(f\" log-variance : {lv:.4f} (идеал ≈ 0.0)\")\n",
119
+ " print(f\" mean : {latents.mean():.4f}\")\n",
120
+ " print(f\" std : {latents.std():.4f}\")\n",
121
+ " print(f\" shape : {latents.shape}\")\n",
122
+ "\n",
123
+ "def plot_latent_distribution(latents, title, save_path):\n",
124
+ " from scipy.stats import probplot\n",
125
+ " lat = latents.detach().cpu().float().numpy().flatten()\n",
126
+ "\n",
127
+ " plt.figure(figsize=(10, 4))\n",
128
+ "\n",
129
+ " plt.subplot(1, 2, 1)\n",
130
+ " plt.hist(lat, bins=100, density=True, alpha=0.7, color=\"steelblue\")\n",
131
+ " plt.title(f\"{title} histogram\")\n",
132
+ " plt.xlabel(\"latent value\")\n",
133
+ " plt.ylabel(\"density\")\n",
134
+ "\n",
135
+ " plt.subplot(1, 2, 2)\n",
136
+ " probplot(lat, dist=\"norm\", plot=plt)\n",
137
+ " plt.title(f\"{title} QQ-plot\")\n",
138
+ "\n",
139
+ " plt.tight_layout()\n",
140
+ " plt.savefig(save_path)\n",
141
+ " plt.close()\n",
142
+ " print(f\" график сохранён: {save_path}\")\n",
143
+ "\n",
144
+ "# ── Нормализация из конфига (per-channel для vae32ch) ────────────────────────\n",
145
+ "def make_norm_tensors(cfg, latent_channels, device, dtype):\n",
146
+ " mean = getattr(cfg, \"latents_mean\", None)\n",
147
+ " std = getattr(cfg, \"latents_std\", None)\n",
148
+ " shift = getattr(cfg, \"shift_factor\", 0.0)\n",
149
+ " scale = getattr(cfg, \"scaling_factor\", 1.0)\n",
150
+ "\n",
151
+ " if mean is not None:\n",
152
+ " mean = torch.tensor(mean, device=device, dtype=dtype).view(1, latent_channels, 1, 1)\n",
153
+ " if std is not None:\n",
154
+ " std = torch.tensor(std, device=device, dtype=dtype).view(1, latent_channels, 1, 1)\n",
155
+ "\n",
156
+ " shift = torch.tensor(shift if shift else 0., device=device, dtype=dtype)\n",
157
+ " scale = torch.tensor(scale, device=device, dtype=dtype)\n",
158
+ " return mean, std, shift, scale\n",
159
+ "\n",
160
+ "# ── Основной цикл ─────────────────────────────────────────────────────────────\n",
161
+ "img, x = load_image(IMG_PATH)\n",
162
+ "img.save(os.path.join(OUT_DIR, \"original.png\"))\n",
163
+ "print(f\"Картинка загружена: {x.shape}\")\n",
164
+ "\n",
165
+ "for name, (kind, repo) in VAES.items():\n",
166
+ " print(f\"\\n{'='*55}\")\n",
167
+ " print(f\"VAE : {name}\")\n",
168
+ " print(f\"repo: {repo}\")\n",
169
+ "\n",
170
+ " # --- загружаем нужный класс ---\n",
171
+ " if kind == \"flux2\":\n",
172
+ " vae = AutoencoderKLFlux2.from_pretrained(\n",
173
+ " repo, subfolder=\"vae\", torch_dtype=dtype\n",
174
+ " ).to(device)\n",
175
+ " else:\n",
176
+ " vae = AutoencoderKL.from_pretrained(\n",
177
+ " repo, torch_dtype=dtype\n",
178
+ " ).to(device)\n",
179
+ " vae.eval()\n",
180
+ "\n",
181
+ " latent_channels = vae.config.latent_channels\n",
182
+ " mean_t, std_t, shift_t, scale_t = make_norm_tensors(\n",
183
+ " vae.config, latent_channels, device, dtype\n",
184
+ " )\n",
185
+ "\n",
186
+ " print(f\"latent_channels : {latent_channels}\")\n",
187
+ " print(f\"scaling_factor : {scale_t.item():.5f}\")\n",
188
+ " print(f\"shift_factor : {shift_t.item():.5f}\")\n",
189
+ " print(f\"latents_mean : {'да (' + str(latent_channels) + 'ch)' if mean_t is not None else 'нет'}\")\n",
190
+ " print(f\"latents_std : {'да (' + str(latent_channels) + 'ch)' if std_t is not None else 'нет'}\")\n",
191
+ "\n",
192
+ " with torch.no_grad():\n",
193
+ "\n",
194
+ " # ── ENCODE ────────────────────────────────────────────────────────────\n",
195
+ " latents = vae.encode(x).latent_dist.sample().to(dtype)\n",
196
+ " print(f\"\\n[encode] raw latents: {latents.shape}\")\n",
197
+ "\n",
198
+ " if kind == \"flux2\":\n",
199
+ " # 32ch → patchify → 128ch\n",
200
+ " latents_patched = _patchify_latents(latents)\n",
201
+ " print(f\"[flux2] patchify : {latents.shape} → {latents_patched.shape}\")\n",
202
+ "\n",
203
+ " # BN нормализация в 128-канальном пространстве\n",
204
+ " bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(device, dtype)\n",
205
+ " bn_std = torch.sqrt(\n",
206
+ " vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps\n",
207
+ " ).to(device, dtype)\n",
208
+ " latents_normed = (latents_patched - bn_mean) / bn_std\n",
209
+ " print(f\"[flux2] BN norm : mean={bn_mean.mean():.4f} std={bn_std.mean():.4f}\")\n",
210
+ "\n",
211
+ " # считаем статистику в 128ch нормализованном пространстве\n",
212
+ " print(\"\\n[STATS] после BN нормализации (128ch):\")\n",
213
+ " print_stats(name, latents_normed)\n",
214
+ " #plot_latent_distribution(\n",
215
+ " # latents_normed,\n",
216
+ " # f\"{name}_latents\",\n",
217
+ " # os.path.join(OUT_DIR, f\"dist_{name}.png\")\n",
218
+ " #)\n",
219
+ "\n",
220
+ " # unpatchify → 32ch (для decode)\n",
221
+ " latents = _unpatchify_latents(latents_normed)\n",
222
+ "\n",
223
+ " else: # vae32ch2\n",
224
+ " # per-channel нормализация из конфига\n",
225
+ " if mean_t is not None and std_t is not None:\n",
226
+ " latents = (latents - mean_t) / std_t\n",
227
+ " latents = (latents - shift_t) / scale_t\n",
228
+ "\n",
229
+ " print(f\"\\n[STATS] после per-channel нормализации ({latent_channels}ch):\")\n",
230
+ " print_stats(name, latents)\n",
231
+ " #plot_latent_distribution(\n",
232
+ " # latents,\n",
233
+ " # f\"{name}_latents\",\n",
234
+ " # os.path.join(OUT_DIR, f\"dist_{name}.png\")\n",
235
+ " #)\n",
236
+ "\n",
237
+ " # ── DECODE ────────────────────────────────────────────────────────────\n",
238
+ " if kind == \"flux2\":\n",
239
+ " # patchify → denorm → unpatchify\n",
240
+ " latents_patched = _patchify_latents(latents)\n",
241
+ " latents_denormed = latents_patched * bn_std + bn_mean\n",
242
+ " latents = _unpatchify_latents(latents_denormed)\n",
243
+ " print(f\"\\n[flux2] BN denorm + unpatchify: {latents.shape}\")\n",
244
+ "\n",
245
+ " else: # vae32ch2\n",
246
+ " latents = latents * scale_t + shift_t\n",
247
+ " if mean_t is not None and std_t is not None:\n",
248
+ " latents = latents * std_t + mean_t\n",
249
+ " print(f\"\\n[vae32ch2] denorm: {latents.shape}\")\n",
250
+ "\n",
251
+ " rec = vae.decode(latents).sample\n",
252
+ "\n",
253
+ " out_path = os.path.join(OUT_DIR, f\"decoded_{name}.png\")\n",
254
+ " tensor_to_img(rec).save(out_path)\n",
255
+ " print(f\"Сохранено: {out_path}\")\n",
256
+ "\n",
257
+ "print(f\"\\n{'='*55}\")\n",
258
+ "print(\"Готово\")"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 3,
264
+ "id": "c219c07b-8da2-4182-ace6-8c3cc63ae3b1",
265
+ "metadata": {},
266
+ "outputs": [
267
+ {
268
+ "name": "stdout",
269
+ "output_type": "stream",
270
+ "text": [
271
+ "Requirement already satisfied: scipy in /usr/local/lib/python3.12/dist-packages (1.17.1)\n",
272
+ "Requirement already satisfied: numpy<2.7,>=1.26.4 in /usr/local/lib/python3.12/dist-packages (from scipy) (2.4.0)\n",
273
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
274
+ "\u001b[0m"
275
+ ]
276
+ },
277
+ {
278
+ "ename": "ModuleNotFoundError",
279
+ "evalue": "No module named 'scipy'",
280
+ "output_type": "error",
281
+ "traceback": [
282
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
283
+ "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
284
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m get_ipython().system(\u001b[33m'\u001b[39m\u001b[33mpip install --user scipy\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mscipy\u001b[39;00m\n\u001b[32m 4\u001b[39m \u001b[38;5;28mprint\u001b[39m(scipy.__version__)\n",
285
+ "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'scipy'"
286
+ ]
287
+ }
288
+ ],
289
+ "source": [
290
+ "!pip install --user scipy\n",
291
+ "\n",
292
+ "import scipy\n",
293
+ "print(scipy.__version__)\n"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": null,
299
+ "id": "43a4e1bc-2b02-4604-b69e-1a5aa276b6ac",
300
+ "metadata": {},
301
+ "outputs": [],
302
+ "source": []
303
+ }
304
+ ],
305
+ "metadata": {
306
+ "kernelspec": {
307
+ "display_name": "Python3 (ipykernel)",
308
+ "language": "python",
309
+ "name": "python3"
310
+ },
311
+ "language_info": {
312
+ "codemirror_mode": {
313
+ "name": "ipython",
314
+ "version": 3
315
+ },
316
+ "file_extension": ".py",
317
+ "mimetype": "text/x-python",
318
+ "name": "python",
319
+ "nbconvert_exporter": "python",
320
+ "pygments_lexer": "ipython3",
321
+ "version": "3.12.12"
322
+ }
323
+ },
324
+ "nbformat": 4,
325
+ "nbformat_minor": 5
326
+ }
config.json ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.36.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 32,
20
+ "latents_mean": [
21
+ -0.03542253375053406,
22
+ 0.20086465775966644,
23
+ -0.016413161531090736,
24
+ -0.0956302210688591,
25
+ -0.2672063112258911,
26
+ 0.2609933018684387,
27
+ -0.07806991040706635,
28
+ -0.48407721519470215,
29
+ 0.21844269335269928,
30
+ -0.1122383326292038,
31
+ 0.27197545766830444,
32
+ -0.18958772718906403,
33
+ 0.18776826560497284,
34
+ 0.0987580344080925,
35
+ 0.2837068736553192,
36
+ -0.4486690163612366,
37
+ 0.4816776514053345,
38
+ 0.02947971224784851,
39
+ -0.1337375044822693,
40
+ -0.39750921726226807,
41
+ -0.08513020724058151,
42
+ -0.054023586213588715,
43
+ -0.3943594992160797,
44
+ 0.23918119072914124,
45
+ -0.12466679513454437,
46
+ 0.09935147315263748,
47
+ 0.31858691573143005,
48
+ 0.48585832118988037,
49
+ -0.6416525840759277,
50
+ -0.15164820849895477,
51
+ -0.4693508744239807,
52
+ -0.13071806728839874
53
+ ],
54
+ "latents_std": [
55
+ 1.5792087316513062,
56
+ 1.5769503116607666,
57
+ 1.5864241123199463,
58
+ 1.6454921960830688,
59
+ 1.5336694717407227,
60
+ 1.5587652921676636,
61
+ 1.5838669538497925,
62
+ 1.5659377574920654,
63
+ 1.6860467195510864,
64
+ 1.5192310810089111,
65
+ 1.573639988899231,
66
+ 1.5953549146652222,
67
+ 1.5271092653274536,
68
+ 1.6246271133422852,
69
+ 1.7054023742675781,
70
+ 1.607722282409668,
71
+ 1.558642864227295,
72
+ 1.5824549198150635,
73
+ 1.6202995777130127,
74
+ 1.6206320524215698,
75
+ 1.6379750967025757,
76
+ 1.6527063846588135,
77
+ 1.498811960220337,
78
+ 1.5706247091293335,
79
+ 1.5854856967926025,
80
+ 1.4828169345855713,
81
+ 1.5693111419677734,
82
+ 1.692481517791748,
83
+ 1.6409776210784912,
84
+ 1.6216280460357666,
85
+ 1.6087706089019775,
86
+ 1.5776633024215698
87
+ ],
88
+ "layers_per_block": 2,
89
+ "mid_block_add_attention": true,
90
+ "norm_num_groups": 32,
91
+ "out_channels": 3,
92
+ "sample_size": 32,
93
+ "scaling_factor": 1.0,
94
+ "shift_factor": 0.0,
95
+ "up_block_types": [
96
+ "UpDecoderBlock2D",
97
+ "UpDecoderBlock2D",
98
+ "UpDecoderBlock2D",
99
+ "UpDecoderBlock2D"
100
+ ],
101
+ "use_post_quant_conv": true,
102
+ "use_quant_conv": true
103
+ }
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6369e370ff02168a240a9ebfd47810dd7babb36f76b7d9999e5d78cb4a1976c2
3
+ size 336212308
scale.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import numpy as np
from PIL import Image
from diffusers import AutoencoderKL
from tqdm import tqdm
import pathlib

# ── 1. Load the VAE ───────────────────────────────────────────────────────────
# Encoder/decoder whose per-channel latent statistics we want to measure.
vae = AutoencoderKL.from_pretrained("vae32ch", torch_dtype=torch.float32)
vae.eval().cuda()

# Spatial downsampling factor of the encoder: one /2 per resolution level
# after the first block (= 8 for 4 levels).
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)  # = 8

# ── 2. Collect all PNGs recursively ───────────────────────────────────────────
dataset_path = pathlib.Path("/workspace/ds")
image_paths = sorted(dataset_path.rglob("*.png"))
print(f"Найдено картинок: {len(image_paths)}")

# Cap the sample size. (The original comment said 3000 while the code took
# 30000; the code's value is kept as the intended one.)
MAX_IMAGES = 30000
image_paths = image_paths[:MAX_IMAGES]

# ── 3. Preprocessing: center-crop to a multiple of 8, no resizing ─────────────
def preprocess(path):
    """Load *path* as an RGB tensor in [-1, 1] with shape [1, 3, H, W].

    H and W are center-cropped down to the nearest multiple of
    ``vae_scale_factor`` so the encoder produces whole latent pixels.
    """
    img = Image.open(path).convert("RGB")
    w, h = img.size

    new_w = (w // vae_scale_factor) * vae_scale_factor
    new_h = (h // vae_scale_factor) * vae_scale_factor

    if new_w != w or new_h != h:
        left = (w - new_w) // 2
        top = (h - new_h) // 2
        img = img.crop((left, top, left + new_w, top + new_h))

    x = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
    x = x.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
    x = x * 2.0 - 1.0                    # [0, 1] → [-1, 1]
    return x

# ── 4. Per-channel latent statistics ──────────────────────────────────────────
latent_channels = vae.config.latent_channels  # 32

all_means = []  # one [C] vector per image
all_stds = []   # one [C] vector per image
errors = []

with torch.no_grad():
    for path in tqdm(image_paths, desc="Encoding"):
        try:
            x = preprocess(path).cuda()
            lat = vae.encode(x).latent_dist.sample()                    # [1, C, H, W]
            flat = lat.squeeze(0).float().reshape(latent_channels, -1)  # [C, H*W]

            all_means.append(flat.mean(dim=1).cpu())  # [C]
            all_stds.append(flat.std(dim=1).cpu())    # [C]

        except Exception as e:
            # Broken/unreadable image: record and keep going.
            errors.append((path, str(e)))

if errors:
    print(f"\nОшибки ({len(errors)}):")
    for p, e in errors:
        print(f"  {p}: {e}")

# Fail with a clear message instead of crashing on torch.stack([]) when
# nothing was encoded (empty dataset dir or all files broken).
if not all_means:
    raise RuntimeError("no images were successfully encoded; check dataset path")

# NOTE(review): this averages per-image means/stds, which is not the pooled
# statistic over all pixels — images of different sizes are weighted equally.
# Kept as-is to preserve the original behavior; confirm this is intended.
mean = torch.stack(all_means).mean(dim=0)  # [C]
std = torch.stack(all_stds).mean(dim=0)    # [C]

print(f"\nОбработано картинок: {len(all_means)}")
print(f"\nlatents_mean ({latent_channels} каналов):")
print(mean.tolist())
print(f"\nlatents_std ({latent_channels} каналов):")
print(std.tolist())

# ── 5. New VAE with the same architecture + the measured scaling vectors ──────
cfg = vae.config

new_vae = AutoencoderKL(
    in_channels        = cfg.in_channels,
    out_channels       = cfg.out_channels,
    latent_channels    = cfg.latent_channels,
    block_out_channels = cfg.block_out_channels,
    layers_per_block   = cfg.layers_per_block,
    norm_num_groups    = cfg.norm_num_groups,
    act_fn             = cfg.act_fn,
    down_block_types   = cfg.down_block_types,
    up_block_types     = cfg.up_block_types,
)
new_vae.eval()

# Copy the trained weights into the freshly constructed model.
result = new_vae.load_state_dict(vae.state_dict(), strict=False)
print(f"\nВеса перенесены: {result}")

# Persist the measured normalization vectors into the model config so
# downstream code can read latents_mean / latents_std from config.json.
new_vae.register_to_config(
    latents_mean   = mean.tolist(),
    latents_std    = std.tolist(),
    scaling_factor = 1.0,
    shift_factor   = 0.0,
)

print(f"\nlatents_mean в конфиге: {new_vae.config.latents_mean[:4]}...")
print(f"latents_std в конфиге: {new_vae.config.latents_std[:4]}...")

# ── 6. Save ───────────────────────────────────────────────────────────────────
new_vae.save_pretrained("vae32ch2")
print("\nСохранено в vae32ch2/")