recoilme committed
Commit 8951b09 · 1 Parent(s): feec669
.gitattributes CHANGED
@@ -39,3 +39,5 @@ samples/sample_decoded.jpg filter=lfs diff=lfs merge=lfs -text
  samples/sample_real.jpg filter=lfs diff=lfs merge=lfs -text
  samples/sample_1.jpg filter=lfs diff=lfs merge=lfs -text
  samples/sample_2.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
down.sh CHANGED
@@ -1,6 +1,6 @@
  #!/bin/bash

- TARGET_DIR="/workspace/d23"
+ TARGET_DIR="/workspace"
  mkdir -p "$TARGET_DIR"

  BASE_URL="https://huggingface.co/datasets/AI-Art-Collab/dtasettar23/resolve/main/d23.tar."
requirements.txt CHANGED
@@ -7,4 +7,5 @@ huggingface_hub>=0.29.3
  bitsandbytes>=0.45.4
  transformers
  hf_transfer
- lpips
+ lpips
+ fdl-pytorch
samples/sample_decoded_0.jpg CHANGED

Git LFS Details

  • SHA256: 666d69e53dabb5a1fb5e5efee1cd9ccf3bd6507f04d611dec96375758ce6c489
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
samples/sample_decoded_1.jpg CHANGED

Git LFS Details

  • SHA256: e94d470adf4219e351d513a3d49ae59872693e7d7e746262b3d4a750981f8d5b
  • Pointer size: 130 Bytes
  • Size of remote file: 75.3 kB
samples/sample_decoded_2.jpg CHANGED

Git LFS Details

  • SHA256: bd9deb9bd68ad50f46d73f8432b5aaca5bb52f9af5aff2473db6cfa8fd8facff
  • Pointer size: 130 Bytes
  • Size of remote file: 89.4 kB
samples/sample_real_0.jpg CHANGED

Git LFS Details

  • SHA256: c43d71f0fd394f8b8708151c7486348e154ec2f2832cd319f608d6e9d2c324f9
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
samples/sample_real_1.jpg CHANGED

Git LFS Details

  • SHA256: 91bc92c30f5ca06e9c34c8e915bf9cae2a710f1fb8a60842b8471428f4e2e3f9
  • Pointer size: 130 Bytes
  • Size of remote file: 69.1 kB
samples/sample_real_2.jpg CHANGED

Git LFS Details

  • SHA256: d230f08be5595fba27c422560a77f161b07d4f642790f72ad122bb2b58b9a99f
  • Pointer size: 130 Bytes
  • Size of remote file: 83.6 kB
train_vae_fdl.py ADDED
@@ -0,0 +1,594 @@
+ # -*- coding: utf-8 -*-
+ import os
+ import math
+ import re
+ import torch
+ import numpy as np
+ import random
+ import gc
+ from datetime import datetime
+ from pathlib import Path
+
+ import torchvision.transforms as transforms
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader, Dataset
+ from torch.optim.lr_scheduler import LambdaLR
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
+ # QWEN: import the Qwen-Image VAE class
+ from diffusers import AutoencoderKLQwenImage
+ from diffusers import AutoencoderKLWan
+
+ from accelerate import Accelerator
+ from PIL import Image, UnidentifiedImageError
+ from tqdm import tqdm
+ import bitsandbytes as bnb
+ import wandb
+ import lpips  # pip install lpips
+ from FDL_pytorch import FDL_loss  # pip install fdl-pytorch
+ from collections import deque
+
+ # --------------------------- Parameters ---------------------------
+ ds_path = "/workspace/d23"
+ project = "vae"
+ batch_size = 2
+ base_learning_rate = 5e-5
+ min_learning_rate = 1e-5
+ num_epochs = 50
+ sample_interval_share = 5
+ use_wandb = True
+ save_model = True
+ use_decay = True
+ optimizer_type = "adam8bit"
+ dtype = torch.float32
+
+ model_resolution = 256
+ high_resolution = 512
+ limit = 0
+ save_barrier = 1.3
+ warmup_percent = 0.01
+ percentile_clipping = 99
+ beta2 = 0.997
+ eps = 1e-8
+ clip_grad_norm = 1.0
+ mixed_precision = "no"
+ gradient_accumulation_steps = 2
+ generated_folder = "samples"
+ save_as = "vae"
+ num_workers = 0
+ device = None
+
+ # --- Training modes ---
+ # QWEN: train only the decoder
+ train_decoder_only = True
+ full_training = False  # if True, train the whole VAE and add a KL term (below)
+ kl_ratio = 0.00
+
+ # Loss shares
+ loss_ratios = {
+     "lpips": 0.25,
+     "fdl":   0.50,
+     "edge":  0.05,
+     "mse":   0.10,
+     "mae":   0.10,
+     "kl":    0.00,  # activated when full_training=True
+ }
+ median_coeff_steps = 1000
+
+ resize_long_side = 1280  # resize the long side of the source images
+
+ # QWEN: model-loading config
+ vae_kind = "kl"  # "qwen", "wan", or "kl" (regular)
+
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
+
+ accelerator = Accelerator(
+     mixed_precision=mixed_precision,
+     gradient_accumulation_steps=gradient_accumulation_steps
+ )
+ device = accelerator.device
+
+ # reproducibility
+ seed = int(datetime.now().strftime("%Y%m%d"))
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
+ torch.backends.cudnn.benchmark = False
+
+ # --------------------------- WandB ---------------------------
+ if use_wandb and accelerator.is_main_process:
+     wandb.init(project=project, config={
+         "batch_size": batch_size,
+         "base_learning_rate": base_learning_rate,
+         "num_epochs": num_epochs,
+         "optimizer_type": optimizer_type,
+         "model_resolution": model_resolution,
+         "high_resolution": high_resolution,
+         "gradient_accumulation_steps": gradient_accumulation_steps,
+         "train_decoder_only": train_decoder_only,
+         "full_training": full_training,
+         "kl_ratio": kl_ratio,
+         "vae_kind": vae_kind,
+     })
+
+ # --------------------------- VAE ---------------------------
+ def get_core_model(model):
+     m = model
+     # unwrap if the model has already been wrapped by torch.compile
+     if hasattr(m, "_orig_mod"):
+         m = m._orig_mod
+     return m
+
+ def is_video_vae(model) -> bool:
+     # WAN/Qwen are video VAEs
+     if vae_kind in ("wan", "qwen"):
+         return True
+     # structural fallback (in case it is ever needed): video VAEs use 3D convs,
+     # so the encoder's conv_in weight is 5-dimensional
+     try:
+         core = get_core_model(model)
+         enc = getattr(core, "encoder", None)
+         conv_in = getattr(enc, "conv_in", None)
+         w = getattr(conv_in, "weight", None)
+         if isinstance(w, torch.nn.Parameter):
+             return w.ndim == 5
+     except Exception:
+         pass
+     return False
+
+ # load the model
+ if vae_kind == "qwen":
+     vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
+ else:
+     if vae_kind == "wan":
+         vae = AutoencoderKLWan.from_pretrained(project)
+     else:
+         # old behavior (example)
+         if model_resolution == high_resolution:
+             vae = AutoencoderKL.from_pretrained(project)
+         else:
+             vae = AsymmetricAutoencoderKL.from_pretrained(project)
+
+ vae = vae.to(dtype)
+
+ # torch.compile (optional)
+ if hasattr(torch, "compile"):
+     try:
+         vae = torch.compile(vae)
+     except Exception as e:
+         print(f"[WARN] torch.compile failed: {e}")
+
+ # --------------------------- Freeze/Unfreeze ---------------------------
+ core = get_core_model(vae)
+
+ for p in core.parameters():
+     p.requires_grad = False
+
+ unfrozen_param_names = []
+
+ if full_training and not train_decoder_only:
+     for name, p in core.named_parameters():
+         p.requires_grad = True
+         unfrozen_param_names.append(name)
+     loss_ratios["kl"] = float(kl_ratio)
+     trainable_module = core
+ else:
+     # train only decoder up-block 0 + post_quant_conv
+     if hasattr(core, "decoder"):
+         if hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
+             # --- only the 0th up_block ---
+             for name, p in core.decoder.up_blocks[0].named_parameters():
+                 p.requires_grad = True
+                 unfrozen_param_names.append(f"decoder.up_blocks.0.{name}")
+         else:
+             print("[WARN] Decoder has no up_blocks; falling back to full decoder")
+             for name, p in core.decoder.named_parameters():
+                 p.requires_grad = True
+                 unfrozen_param_names.append(f"decoder.{name}")
+     if hasattr(core, "post_quant_conv"):
+         for name, p in core.post_quant_conv.named_parameters():
+             p.requires_grad = True
+             unfrozen_param_names.append(f"post_quant_conv.{name}")
+     trainable_module = core.decoder if hasattr(core, "decoder") else core
+
+ print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
+ for nm in unfrozen_param_names[:200]:
+     print("  ", nm)
+
+ # --------------------------- Dataset ---------------------------
+ class PngFolderDataset(Dataset):
+     def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
+         self.root_dir = root_dir
+         self.resolution = resolution
+         self.paths = []
+         for root, _, files in os.walk(root_dir):
+             for fname in files:
+                 if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
+                     self.paths.append(os.path.join(root, fname))
+         if limit:
+             self.paths = self.paths[:limit]
+         # drop unreadable files up front
+         valid = []
+         for p in self.paths:
+             try:
+                 with Image.open(p) as im:
+                     im.verify()
+                 valid.append(p)
+             except (OSError, UnidentifiedImageError):
+                 continue
+         self.paths = valid
+         if len(self.paths) == 0:
+             raise RuntimeError(f"No valid PNG images found under {root_dir}")
+         random.shuffle(self.paths)
+
+     def __len__(self):
+         return len(self.paths)
+
+     def __getitem__(self, idx):
+         p = self.paths[idx % len(self.paths)]
+         with Image.open(p) as img:
+             img = img.convert("RGB")
+         if not resize_long_side or resize_long_side <= 0:
+             return img
+         w, h = img.size
+         long_side = max(w, h)
+         if long_side <= resize_long_side:
+             return img
+         scale = resize_long_side / float(long_side)
+         new_w = int(round(w * scale))
+         new_h = int(round(h * scale))
+         return img.resize((new_w, new_h), Image.BICUBIC)
+
+ def random_crop(img, sz):
+     w, h = img.size
+     if w < sz or h < sz:
+         img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
+     x = random.randint(0, max(0, img.width - sz))
+     y = random.randint(0, max(0, img.height - sz))
+     return img.crop((x, y, x + sz, y + sz))
+
+ tfm = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+ ])
+
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
+ if len(dataset) < batch_size:
+     raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
+
+ def collate_fn(batch):
+     imgs = []
+     for img in batch:
+         img = random_crop(img, high_resolution)
+         imgs.append(tfm(img))
+     return torch.stack(imgs)
+
+ dataloader = DataLoader(
+     dataset,
+     batch_size=batch_size,
+     shuffle=True,
+     collate_fn=collate_fn,
+     num_workers=num_workers,
+     pin_memory=True,
+     drop_last=True
+ )
+
+ # --------------------------- Optimizer ---------------------------
+ def get_param_groups(module, weight_decay=0.001):
+     # bias/normalization parameters are excluded from weight decay
+     no_decay_tokens = ("bias", "norm", "rms", "layernorm")
+     decay_params, no_decay_params = [], []
+     for n, p in module.named_parameters():
+         if not p.requires_grad:
+             continue
+         n_l = n.lower()
+         if any(t in n_l for t in no_decay_tokens):
+             no_decay_params.append(p)
+         else:
+             decay_params.append(p)
+     return [
+         {"params": decay_params, "weight_decay": weight_decay},
+         {"params": no_decay_params, "weight_decay": 0.0},
+     ]
+
+ def create_optimizer(name, param_groups):
+     if name == "adam8bit":
+         return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
+     raise ValueError(name)
+
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
+ optimizer = create_optimizer(optimizer_type, param_groups)
+
+ # --------------------------- LR schedule ---------------------------
+ batches_per_epoch = len(dataloader)
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
+ total_steps = steps_per_epoch * num_epochs
+
+ def lr_lambda(step):
+     if not use_decay:
+         return 1.0
+     x = float(step) / float(max(1, total_steps))
+     warmup = float(warmup_percent)
+     min_ratio = float(min_learning_rate) / float(base_learning_rate)
+     if x < warmup:
+         # linear warmup from min_ratio up to 1.0
+         return min_ratio + (1.0 - min_ratio) * (x / warmup)
+     # cosine decay from 1.0 back down to min_ratio
+     decay_ratio = (x - warmup) / (1.0 - warmup)
+     return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
+
+ scheduler = LambdaLR(optimizer, lr_lambda)
+
+ # Prepare with accelerate
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
+ trainable_params = [p for p in vae.parameters() if p.requires_grad]
+
+ # FDL loss
+ fdl_loss = FDL_loss()
+ fdl_loss = fdl_loss.to(accelerator.device)
+
+ # --------------------------- LPIPS and helpers ---------------------------
+ _lpips_net = None
+ def _get_lpips():
+     global _lpips_net
+     if _lpips_net is None:
+         _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
+     return _lpips_net
+
+ _sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
+ _sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)
+ def sobel_edges(x: torch.Tensor) -> torch.Tensor:
+     # depthwise Sobel filters: per-channel gradient magnitude
+     C = x.shape[1]
+     kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+     ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+     gx = F.conv2d(x, kx, padding=1, groups=C)
+     gy = F.conv2d(x, ky, padding=1, groups=C)
+     return torch.sqrt(gx * gx + gy * gy + 1e-12)
+
+ class MedianLossNormalizer:
+     def __init__(self, desired_ratios: dict, window_steps: int):
+         s = sum(desired_ratios.values())
+         self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
+         self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
+         self.window = window_steps
+
+     def update_and_total(self, abs_losses: dict):
+         for k, v in abs_losses.items():
+             if k in self.buffers:
+                 self.buffers[k].append(float(v.detach().abs().cpu()))
+         meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
+         coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
+         total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
+         return total, coeffs, meds
+
+ if full_training and not train_decoder_only:
+     loss_ratios["kl"] = float(kl_ratio)
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
+
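+ # How the normalizer balances terms: each raw loss is divided by the rolling
+ # median of its own recent magnitudes, so every term sits near 1.0 before the
+ # target shares are applied. For example, with shares lpips=0.25 and fdl=0.50,
+ # a step where lpips=0.20 (median 0.20) and fdl=4.0 (median 4.0) contributes
+ # 0.25*1.0 and 0.50*1.0 respectively, whatever the raw scales are.
+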
+ # --------------------------- Samples ---------------------------
+ @torch.no_grad()
+ def get_fixed_samples(n=3):
+     idx = random.sample(range(len(dataset)), min(n, len(dataset)))
+     pil_imgs = [dataset[i] for i in idx]
+     tensors = []
+     for img in pil_imgs:
+         img = random_crop(img, high_resolution)
+         tensors.append(tfm(img))
+     return torch.stack(tensors).to(accelerator.device, dtype)
+
+ fixed_samples = get_fixed_samples()
+
+ @torch.no_grad()
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
+     # [-1, 1] float tensor -> HWC uint8 PIL image
+     arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
+     return Image.fromarray(arr)
+
+ @torch.no_grad()
+ def generate_and_save_samples(step=None):
+     try:
+         temp_vae = accelerator.unwrap_model(vae).eval()
+         lpips_net = _get_lpips()
+         with torch.no_grad():
+             orig_high = fixed_samples
+             orig_low = F.interpolate(
+                 orig_high,
+                 size=(model_resolution, model_resolution),
+                 mode="bilinear",
+                 align_corners=False
+             )
+             model_dtype = next(temp_vae.parameters()).dtype
+             orig_low = orig_low.to(dtype=model_dtype)
+
+             # encode/decode, accounting for the video mode
+             if is_video_vae(temp_vae):
+                 x_in = orig_low.unsqueeze(2)  # [B,3,1,H,W]
+                 enc = temp_vae.encode(x_in)
+                 latents_mean = enc.latent_dist.mean
+                 dec = temp_vae.decode(latents_mean).sample  # [B,3,1,H,W]
+                 rec = dec.squeeze(2)  # [B,3,H,W]
+             else:
+                 enc = temp_vae.encode(orig_low)
+                 latents_mean = enc.latent_dist.mean
+                 rec = temp_vae.decode(latents_mean).sample
+
+             # match sizes if needed
+             if rec.shape[-2:] != orig_high.shape[-2:]:
+                 rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
+
+             # save every real/decoded pair
+             for i in range(rec.shape[0]):
+                 real_img = _to_pil_uint8(orig_high[i])
+                 dec_img = _to_pil_uint8(rec[i])
+                 real_img.save(f"{generated_folder}/sample_real_{i}.jpg", quality=95)
+                 dec_img.save(f"{generated_folder}/sample_decoded_{i}.jpg", quality=95)
+
+             # LPIPS
+             lpips_scores = []
+             for i in range(rec.shape[0]):
+                 orig_full = orig_high[i:i+1].to(torch.float32)
+                 rec_full = rec[i:i+1].to(torch.float32)
+                 if rec_full.shape[-2:] != orig_full.shape[-2:]:
+                     rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
+                 lpips_val = lpips_net(orig_full, rec_full).item()
+                 lpips_scores.append(lpips_val)
+             avg_lpips = float(np.mean(lpips_scores))
+
+             # W&B logging
+             if use_wandb and accelerator.is_main_process:
+                 log_data = {"lpips_mean": avg_lpips}
+                 for i in range(rec.shape[0]):
+                     log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.jpg", caption=f"real_{i}")
+                     log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.jpg", caption=f"decoded_{i}")
+                 wandb.log(log_data, step=step)
+
+     finally:
+         gc.collect()
+         torch.cuda.empty_cache()
+
+
+ if accelerator.is_main_process and save_model:
+     print("Generating samples before training starts...")
+     generate_and_save_samples(0)
+
+ accelerator.wait_for_everyone()
+
+ # --------------------------- Training ---------------------------
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
+ global_step = 0
+ min_loss = float("inf")
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
+
+ for epoch in range(num_epochs):
+     vae.train()
+     batch_losses, batch_grads = [], []
+     track_losses = {k: [] for k in loss_ratios.keys()}
+
+     for imgs in dataloader:
+         with accelerator.accumulate(vae):
+             imgs = imgs.to(accelerator.device)
+
+             if high_resolution != model_resolution:
+                 imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
+             else:
+                 imgs_low = imgs
+
+             model_dtype = next(vae.parameters()).dtype
+             imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
+
+             # QWEN: encode/decode with T=1
+             if is_video_vae(vae):
+                 x_in = imgs_low_model.unsqueeze(2)  # [B,3,1,H,W]
+                 enc = vae.encode(x_in)
+                 latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
+                 dec = vae.decode(latents).sample  # [B,3,1,H,W]
+                 rec = dec.squeeze(2)  # [B,3,H,W]
+             else:
+                 enc = vae.encode(imgs_low_model)
+                 latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
+                 rec = vae.decode(latents).sample
+
+             if rec.shape[-2:] != imgs.shape[-2:]:
+                 rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
+
+             rec_f32 = rec.to(torch.float32)
+             imgs_f32 = imgs.to(torch.float32)
+
+             abs_losses = {
+                 "mae": F.l1_loss(rec_f32, imgs_f32),
+                 "mse": F.mse_loss(rec_f32, imgs_f32),
+                 "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
+                 "fdl": fdl_loss(rec_f32, imgs_f32),
+                 "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
+             }
+
+             if full_training and not train_decoder_only:
+                 mean = enc.latent_dist.mean
+                 logvar = enc.latent_dist.logvar
+                 kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
+                 abs_losses["kl"] = kl
+             else:
+                 abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
+
+             total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
+
+             if torch.isnan(total_loss) or torch.isinf(total_loss):
+                 raise RuntimeError("NaN/Inf loss")
+
+             accelerator.backward(total_loss)
+
+             grad_norm = torch.tensor(0.0, device=accelerator.device)
+             if accelerator.sync_gradients:
+                 grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
+                 optimizer.step()
+                 scheduler.step()
+                 optimizer.zero_grad(set_to_none=True)
+                 global_step += 1
+                 progress.update(1)
+
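+             # Note: the optimizer, scheduler, and global_step advance only on
+             # sync steps, so global_step counts optimizer updates and stays
+             # consistent with total_steps used by the LR schedule above.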
+             if accelerator.is_main_process:
+                 try:
+                     current_lr = optimizer.param_groups[0]["lr"]
+                 except Exception:
+                     current_lr = scheduler.get_last_lr()[0]
+
+                 batch_losses.append(total_loss.detach().item())
+                 batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
+                 for k, v in abs_losses.items():
+                     track_losses[k].append(float(v.detach().item()))
+
+                 if use_wandb and accelerator.sync_gradients:
+                     log_dict = {
+                         "total_loss": float(total_loss.detach().item()),
+                         "learning_rate": current_lr,
+                         "epoch": epoch,
+                         "grad_norm": batch_grads[-1],
+                     }
+                     for k, v in abs_losses.items():
+                         log_dict[f"loss_{k}"] = float(v.detach().item())
+                     for k in coeffs:
+                         log_dict[f"coeff_{k}"] = float(coeffs[k])
+                         log_dict[f"median_{k}"] = float(meds[k])
+                     wandb.log(log_dict, step=global_step)
+
+             if global_step > 0 and global_step % sample_interval == 0:
+                 if accelerator.is_main_process:
+                     generate_and_save_samples(global_step)
+                 accelerator.wait_for_everyone()
+
+                 n_micro = sample_interval * gradient_accumulation_steps
+                 avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
+                 avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= n_micro else float(np.mean(batch_grads)) if batch_grads else 0.0
+
+                 if accelerator.is_main_process:
+                     print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
+                     if save_model and avg_loss < min_loss * save_barrier:
+                         min_loss = avg_loss
+                         accelerator.unwrap_model(vae).save_pretrained(save_as)
+                     if use_wandb:
+                         wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
+
+     if accelerator.is_main_process:
+         epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
+         print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
+         if use_wandb:
+             wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
+
+ # --------------------------- Final save ---------------------------
+ if accelerator.is_main_process:
+     print("Training finished – saving final model")
+     if save_model:
+         accelerator.unwrap_model(vae).save_pretrained(save_as)
+
+ accelerator.free_memory()
+ if torch.distributed.is_initialized():
+     torch.distributed.destroy_process_group()
+ print("Done!")
train_vae.py → train_vae_old.py RENAMED
File without changes
vae/config.json CHANGED
@@ -1,7 +1,7 @@
  {
    "_class_name": "AsymmetricAutoencoderKL",
    "_diffusers_version": "0.35.2",
-   "_name_or_path": "vae3",
+   "_name_or_path": "vae",
    "act_fn": "silu",
    "block_out_channels": [
      128,
vae/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:95458d45a65f10704a9ce7c75176dcead07708273ae647a40acfb8d9ebc54106
+ oid sha256:65f78c989d2ae138dc170af3dd84ad35d16427a1ca1045468e1f5aff7a8bf682
  size 382598708