recoilme commited on
Commit
f8803e6
·
1 Parent(s): 8592059
down.sh CHANGED
File without changes
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers>=0.32.2
2
+ accelerate>=1.5.2
3
+ datasets>=3.5.0
4
+ matplotlib>=3.10.1
5
+ wandb>=0.19.8
6
+ huggingface_hub>=0.29.3
7
+ bitsandbytes>=0.45.4
8
+ transformers
9
+ hf_transfer
10
+ lpips
samples/sample_0.jpg CHANGED

Git LFS Details

  • SHA256: c4e815d64f34c24b8e56359be835dfaed9442439a95a387076b86e76791dd216
  • Pointer size: 130 Bytes
  • Size of remote file: 97 kB

Git LFS Details

  • SHA256: 2afa6db6e3af8fb931230b3f97abfe11b4925ee6c332a8fbd5d1c5f7d8d22593
  • Pointer size: 130 Bytes
  • Size of remote file: 63.8 kB
samples/sample_1.jpg CHANGED

Git LFS Details

  • SHA256: 7581064f43dc8e29b122e9e2c8e85d17db746999f8c96970d92ce3503e819d09
  • Pointer size: 131 Bytes
  • Size of remote file: 216 kB

Git LFS Details

  • SHA256: db182a72ee0953837f023de75e6ef1d8f110032877c77ee8ab20f84b04d07902
  • Pointer size: 130 Bytes
  • Size of remote file: 58.6 kB
samples/sample_2.jpg CHANGED

Git LFS Details

  • SHA256: 2a1dc60b95faf2711b2cc2d24ea50a676db5bd95884787f7b4d26a39f2431fbd
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB

Git LFS Details

  • SHA256: bfbbdcc88844acf204dc8cc621c01ccf2edf44e4e290235c6aaba3b005e1b652
  • Pointer size: 130 Bytes
  • Size of remote file: 62.4 kB
samples/sample_decoded.jpg CHANGED

Git LFS Details

  • SHA256: c4e815d64f34c24b8e56359be835dfaed9442439a95a387076b86e76791dd216
  • Pointer size: 130 Bytes
  • Size of remote file: 97 kB

Git LFS Details

  • SHA256: 2afa6db6e3af8fb931230b3f97abfe11b4925ee6c332a8fbd5d1c5f7d8d22593
  • Pointer size: 130 Bytes
  • Size of remote file: 63.8 kB
samples/sample_decoded_0.jpg ADDED
samples/sample_decoded_1.jpg ADDED
samples/sample_decoded_2.jpg ADDED
samples/sample_real.jpg CHANGED

Git LFS Details

  • SHA256: 3cc35d9bbece554c4c56a1ab764588e545c4292654d3288e43f8abffa04d249a
  • Pointer size: 130 Bytes
  • Size of remote file: 87.6 kB

Git LFS Details

  • SHA256: aced529a0e633c3fbd35abcc4147dd0108e824bbe4fb8acc097a2a2c521691c2
  • Pointer size: 130 Bytes
  • Size of remote file: 59.7 kB
samples/sample_real_0.jpg ADDED
samples/sample_real_1.jpg ADDED
samples/sample_real_2.jpg ADDED
train_vae.py CHANGED
@@ -28,20 +28,20 @@ from collections import deque
28
 
29
  # --------------------------- Параметры ---------------------------
30
  ds_path = "/workspace/d23"
31
- project = "vae2"
32
- batch_size = 1
33
- base_learning_rate = 6e-6
34
- min_learning_rate = 8e-8
35
- num_epochs = 20
36
- sample_interval_share = 10
37
  use_wandb = True
38
  save_model = True
39
  use_decay = True
40
  optimizer_type = "adam8bit"
41
  dtype = torch.float32
42
 
43
- model_resolution = 512
44
- high_resolution = 1024
45
  limit = 0
46
  save_barrier = 1.3
47
  warmup_percent = 0.001
@@ -50,9 +50,9 @@ beta2 = 0.997
50
  eps = 1e-8
51
  clip_grad_norm = 1.0
52
  mixed_precision = "no"
53
- gradient_accumulation_steps = 16
54
  generated_folder = "samples"
55
- save_as = "vae2"
56
  num_workers = 0
57
  device = None
58
 
@@ -70,7 +70,7 @@ loss_ratios = {
70
  "mae": 0.10,
71
  "kl": 0.00, # активируем при full_training=True
72
  }
73
- median_coeff_steps = 256
74
 
75
  resize_long_side = 1280 # ресайз длинной стороны исходных картинок
76
 
@@ -385,8 +385,72 @@ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
385
  arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
386
  return Image.fromarray(arr)
387
 
 
388
  @torch.no_grad()
389
  def generate_and_save_samples(step=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  try:
391
  temp_vae = accelerator.unwrap_model(vae).eval()
392
  lpips_net = _get_lpips()
 
28
 
29
  # --------------------------- Параметры ---------------------------
30
  ds_path = "/workspace/d23"
31
+ project = "vae3"
32
+ batch_size = 5
33
+ base_learning_rate = 5e-5
34
+ min_learning_rate = 1e-5
35
+ num_epochs = 50
36
+ sample_interval_share = 2
37
  use_wandb = True
38
  save_model = True
39
  use_decay = True
40
  optimizer_type = "adam8bit"
41
  dtype = torch.float32
42
 
43
+ model_resolution = 256
44
+ high_resolution = 512
45
  limit = 0
46
  save_barrier = 1.3
47
  warmup_percent = 0.001
 
50
  eps = 1e-8
51
  clip_grad_norm = 1.0
52
  mixed_precision = "no"
53
+ gradient_accumulation_steps = 2
54
  generated_folder = "samples"
55
+ save_as = "vae3"
56
  num_workers = 0
57
  device = None
58
 
 
70
  "mae": 0.10,
71
  "kl": 0.00, # активируем при full_training=True
72
  }
73
+ median_coeff_steps = 1000
74
 
75
  resize_long_side = 1280 # ресайз длинной стороны исходных картинок
76
 
 
385
  arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
386
  return Image.fromarray(arr)
387
 
388
+
389
  @torch.no_grad()
390
  def generate_and_save_samples(step=None):
391
+ try:
392
+ temp_vae = accelerator.unwrap_model(vae).eval()
393
+ lpips_net = _get_lpips()
394
+ with torch.no_grad():
395
+ orig_high = fixed_samples
396
+ orig_low = F.interpolate(
397
+ orig_high,
398
+ size=(model_resolution, model_resolution),
399
+ mode="bilinear",
400
+ align_corners=False
401
+ )
402
+ model_dtype = next(temp_vae.parameters()).dtype
403
+ orig_low = orig_low.to(dtype=model_dtype)
404
+
405
+ # Encode/decode с учётом видео-режима
406
+ if is_video_vae(temp_vae):
407
+ x_in = orig_low.unsqueeze(2) # [B,3,1,H,W]
408
+ enc = temp_vae.encode(x_in)
409
+ latents_mean = enc.latent_dist.mean
410
+ dec = temp_vae.decode(latents_mean).sample # [B,3,1,H,W]
411
+ rec = dec.squeeze(2) # [B,3,H,W]
412
+ else:
413
+ enc = temp_vae.encode(orig_low)
414
+ latents_mean = enc.latent_dist.mean
415
+ rec = temp_vae.decode(latents_mean).sample
416
+
417
+ # Подгон размеров, если надо
418
+ if rec.shape[-2:] != orig_high.shape[-2:]:
419
+ rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
420
+
421
+ # Сохраняем все real/decoded
422
+ for i in range(rec.shape[0]):
423
+ real_img = _to_pil_uint8(orig_high[i])
424
+ dec_img = _to_pil_uint8(rec[i])
425
+ real_img.save(f"{generated_folder}/sample_real_{i}.jpg", quality=95)
426
+ dec_img.save(f"{generated_folder}/sample_decoded_{i}.jpg", quality=95)
427
+
428
+ # LPIPS
429
+ lpips_scores = []
430
+ for i in range(rec.shape[0]):
431
+ orig_full = orig_high[i:i+1].to(torch.float32)
432
+ rec_full = rec[i:i+1].to(torch.float32)
433
+ if rec_full.shape[-2:] != orig_full.shape[-2:]:
434
+ rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
435
+ lpips_val = lpips_net(orig_full, rec_full).item()
436
+ lpips_scores.append(lpips_val)
437
+ avg_lpips = float(np.mean(lpips_scores))
438
+
439
+ # W&B логирование
440
+ if use_wandb and accelerator.is_main_process:
441
+ log_data = {"lpips_mean": avg_lpips}
442
+ for i in range(rec.shape[0]):
443
+ log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.jpg", caption=f"real_{i}")
444
+ log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.jpg", caption=f"decoded_{i}")
445
+ wandb.log(log_data, step=step)
446
+
447
+ finally:
448
+ gc.collect()
449
+ torch.cuda.empty_cache()
450
+
451
+
452
+
453
+ def generate_and_save_samples2(step=None):
454
  try:
455
  temp_vae = accelerator.unwrap_model(vae).eval()
456
  lpips_net = _get_lpips()
transfer_simplevae3.ipynb ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "c15deb04-94a0-4073-a174-adcd22af10b8",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "✅ Создана новая модель: <class 'diffusers.models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL'>\n"
14
+ ]
15
+ },
16
+ {
17
+ "data": {
18
+ "application/vnd.jupyter.widget-view+json": {
19
+ "model_id": "e2063f203ab844489f3c02cb9c2ae70b",
20
+ "version_major": 2,
21
+ "version_minor": 0
22
+ },
23
+ "text/plain": [
24
+ "config.json: 0%| | 0.00/801 [00:00<?, ?B/s]"
25
+ ]
26
+ },
27
+ "metadata": {},
28
+ "output_type": "display_data"
29
+ },
30
+ {
31
+ "data": {
32
+ "application/vnd.jupyter.widget-view+json": {
33
+ "model_id": "d33d67a744ee43b3b9eaeba9228ba976",
34
+ "version_major": 2,
35
+ "version_minor": 0
36
+ },
37
+ "text/plain": [
38
+ "vae/diffusion_pytorch_model.safetensors: 0%| | 0.00/168M [00:00<?, ?B/s]"
39
+ ]
40
+ },
41
+ "metadata": {},
42
+ "output_type": "display_data"
43
+ },
44
+ {
45
+ "name": "stderr",
46
+ "output_type": "stream",
47
+ "text": [
48
+ "The config attributes {'block_out_channels': [128, 128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
49
+ ]
50
+ },
51
+ {
52
+ "name": "stdout",
53
+ "output_type": "stream",
54
+ "text": [
55
+ "\n",
56
+ "--- Перенос весов ---\n"
57
+ ]
58
+ },
59
+ {
60
+ "name": "stderr",
61
+ "output_type": "stream",
62
+ "text": [
63
+ "100%|██████████| 248/248 [00:00<00:00, 87271.36it/s]"
64
+ ]
65
+ },
66
+ {
67
+ "name": "stdout",
68
+ "output_type": "stream",
69
+ "text": [
70
+ "\n",
71
+ "✅ Перенос завершён.\n",
72
+ "Статистика:\n",
73
+ " перенесено: 216\n",
74
+ " дублировано: 26\n",
75
+ " пропущено: 0\n"
76
+ ]
77
+ },
78
+ {
79
+ "name": "stderr",
80
+ "output_type": "stream",
81
+ "text": [
82
+ "\n"
83
+ ]
84
+ }
85
+ ],
86
+ "source": [
87
+ "from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL\n",
88
+ "import torch\n",
89
+ "from tqdm import tqdm\n",
90
+ "\n",
91
+ "# ---- Конфиг новой модели ----\n",
92
+ "config = {\n",
93
+ " \"_class_name\": \"AsymmetricAutoencoderKL\",\n",
94
+ " \"act_fn\": \"silu\",\n",
95
+ " \"in_channels\": 3,\n",
96
+ " \"out_channels\": 3,\n",
97
+ " \"scaling_factor\": 1.0,\n",
98
+ " \"norm_num_groups\": 32,\n",
99
+ " \"down_block_out_channels\": [128, 256, 512, 512],\n",
100
+ " \"down_block_types\": [\n",
101
+ " \"DownEncoderBlock2D\",\n",
102
+ " \"DownEncoderBlock2D\",\n",
103
+ " \"DownEncoderBlock2D\",\n",
104
+ " \"DownEncoderBlock2D\",\n",
105
+ " ],\n",
106
+ " \"latent_channels\": 16,\n",
107
+ " # Новый UpDecoderBlock добавлен в начало\n",
108
+ " \"up_block_out_channels\": [128, 128, 256, 512, 512],\n",
109
+ " \"up_block_types\": [\n",
110
+ " \"UpDecoderBlock2D\",\n",
111
+ " \"UpDecoderBlock2D\",\n",
112
+ " \"UpDecoderBlock2D\",\n",
113
+ " \"UpDecoderBlock2D\",\n",
114
+ " \"UpDecoderBlock2D\",\n",
115
+ " ],\n",
116
+ "}\n",
117
+ "\n",
118
+ "# ---- Создание пустой асимметричной модели ----\n",
119
+ "vae = AsymmetricAutoencoderKL(\n",
120
+ " act_fn=config[\"act_fn\"],\n",
121
+ " down_block_out_channels=config[\"down_block_out_channels\"],\n",
122
+ " down_block_types=config[\"down_block_types\"],\n",
123
+ " latent_channels=config[\"latent_channels\"],\n",
124
+ " up_block_out_channels=config[\"up_block_out_channels\"],\n",
125
+ " up_block_types=config[\"up_block_types\"],\n",
126
+ " in_channels=config[\"in_channels\"],\n",
127
+ " out_channels=config[\"out_channels\"],\n",
128
+ " scaling_factor=config[\"scaling_factor\"],\n",
129
+ " norm_num_groups=config[\"norm_num_groups\"],\n",
130
+ " layers_per_down_block=2,\n",
131
+ " layers_per_up_block=2,\n",
132
+ " sample_size=1024\n",
133
+ ")\n",
134
+ "\n",
135
+ "vae.save_pretrained(\"asymmetric_vae_empty\")\n",
136
+ "print(\"✅ Создана новая модель:\", type(vae))\n",
137
+ "\n",
138
+ "# ---- Функция переноса весов старого VAE ----\n",
139
+ "def transfer_weights(old_path, new_path, save_path=\"asymmetric_vae\", device=\"cuda\", dtype=torch.float16):\n",
140
+ " old_vae = AutoencoderKL.from_pretrained(old_path, subfolder=\"vae\").to(device, dtype=dtype)\n",
141
+ " new_vae = AsymmetricAutoencoderKL.from_pretrained(new_path).to(device, dtype=dtype)\n",
142
+ "\n",
143
+ " old_sd = old_vae.state_dict()\n",
144
+ " new_sd = new_vae.state_dict()\n",
145
+ "\n",
146
+ " transferred_keys = set()\n",
147
+ " transfer_stats = {\"перенесено\": 0, \"дублировано\": 0, \"пропущено\": 0}\n",
148
+ "\n",
149
+ " print(\"\\n--- Перенос весов ---\")\n",
150
+ " for k, v in tqdm(old_sd.items()):\n",
151
+ " # Копирование энкодера и прочих совпадающих ключей\n",
152
+ " if (\"encoder\" in k) or (\"quant_conv\" in k) or (\"post_quant_conv\" in k):\n",
153
+ " if k in new_sd and new_sd[k].shape == v.shape:\n",
154
+ " new_sd[k] = v.clone()\n",
155
+ " transferred_keys.add(k)\n",
156
+ " transfer_stats[\"перенесено\"] += 1\n",
157
+ " continue\n",
158
+ "\n",
159
+ " # Копирование декодера (без сдвига)\n",
160
+ " if \"decoder.up_blocks\" in k:\n",
161
+ " if k in new_sd and new_sd[k].shape == v.shape:\n",
162
+ " new_sd[k] = v.clone()\n",
163
+ " transferred_keys.add(k)\n",
164
+ " transfer_stats[\"перенесено\"] += 1\n",
165
+ " continue\n",
166
+ "\n",
167
+ " # Дублирование весов старого первого 512→512 блока в новый блок 64→128 для апскейла\n",
168
+ " ref_prefix = \"decoder.up_blocks.1\"\n",
169
+ " new_prefix = \"decoder.up_blocks.0\"\n",
170
+ " for k, v in old_sd.items():\n",
171
+ " if k.startswith(ref_prefix) and new_prefix + k[len(ref_prefix):] in new_sd:\n",
172
+ " new_k = k.replace(ref_prefix, new_prefix)\n",
173
+ " if new_sd[new_k].shape == v.shape:\n",
174
+ " new_sd[new_k] = v.clone()\n",
175
+ " transferred_keys.add(new_k)\n",
176
+ " transfer_stats[\"дублировано\"] += 1\n",
177
+ "\n",
178
+ " # Загрузка и сохранение\n",
179
+ " new_vae.load_state_dict(new_sd, strict=False)\n",
180
+ " new_vae.save_pretrained(save_path)\n",
181
+ "\n",
182
+ " print(\"\\n✅ Перенос завершён.\")\n",
183
+ " print(\"Статистика:\")\n",
184
+ " for k, v in transfer_stats.items():\n",
185
+ " print(f\" {k}: {v}\")\n",
186
+ "\n",
187
+ "# ---- Запуск переноса ----\n",
188
+ "transfer_weights(\"AiArtLab/simplevae\", \"asymmetric_vae_empty\", save_path=\"vae3\")\n"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "id": "59fcafb9-6d89-49b4-8362-b4891f591687",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": []
198
+ }
199
+ ],
200
+ "metadata": {
201
+ "kernelspec": {
202
+ "display_name": "Python 3 (ipykernel)",
203
+ "language": "python",
204
+ "name": "python3"
205
+ },
206
+ "language_info": {
207
+ "codemirror_mode": {
208
+ "name": "ipython",
209
+ "version": 3
210
+ },
211
+ "file_extension": ".py",
212
+ "mimetype": "text/x-python",
213
+ "name": "python",
214
+ "nbconvert_exporter": "python",
215
+ "pygments_lexer": "ipython3",
216
+ "version": "3.12.3"
217
+ }
218
+ },
219
+ "nbformat": 4,
220
+ "nbformat_minor": 5
221
+ }
untitled.txt DELETED
File without changes
vae3/config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AsymmetricAutoencoderKL",
3
+ "_diffusers_version": "0.35.2",
4
+ "_name_or_path": "vae3",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 128,
9
+ 256,
10
+ 512,
11
+ 512
12
+ ],
13
+ "down_block_out_channels": [
14
+ 128,
15
+ 256,
16
+ 512,
17
+ 512
18
+ ],
19
+ "down_block_types": [
20
+ "DownEncoderBlock2D",
21
+ "DownEncoderBlock2D",
22
+ "DownEncoderBlock2D",
23
+ "DownEncoderBlock2D"
24
+ ],
25
+ "force_upcast": false,
26
+ "in_channels": 3,
27
+ "latent_channels": 16,
28
+ "layers_per_down_block": 2,
29
+ "layers_per_up_block": 2,
30
+ "norm_num_groups": 32,
31
+ "out_channels": 3,
32
+ "sample_size": 1024,
33
+ "scaling_factor": 1.0,
34
+ "up_block_out_channels": [
35
+ 128,
36
+ 128,
37
+ 256,
38
+ 512,
39
+ 512
40
+ ],
41
+ "up_block_types": [
42
+ "UpDecoderBlock2D",
43
+ "UpDecoderBlock2D",
44
+ "UpDecoderBlock2D",
45
+ "UpDecoderBlock2D",
46
+ "UpDecoderBlock2D"
47
+ ]
48
+ }
vae3/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da227de34295ccebed70b2fc0879e721a75c0e1ccb28a7a65a7e54651b291260
3
+ size 382598708