babkasotona commited on
Commit
efd88e3
·
verified ·
1 Parent(s): 6bbb9ef

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. train_sdxs_vae-Copy1.py +627 -0
train_sdxs_vae-Copy1.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import gc
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import torchvision.transforms as transforms
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
17
+ # QWEN: импорт класса
18
+ from diffusers import AutoencoderKLQwenImage
19
+ from diffusers import AutoencoderKLWan
20
+
21
+ from accelerate import Accelerator
22
+ from PIL import Image, UnidentifiedImageError
23
+ from tqdm import tqdm
24
+ import bitsandbytes as bnb
25
+ import wandb
26
+ import lpips # pip install lpips
27
+ from FDL_pytorch import FDL_loss # pip install fdl-pytorch
28
+ from collections import deque
29
+
30
+ # --------------------------- Параметры ---------------------------
31
+ ds_path = "/workspace/d23"
32
+ project = "vae2"
33
+ batch_size = 1
34
+ base_learning_rate = 6e-6
35
+ min_learning_rate = 7e-7
36
+ num_epochs = 1
37
+ sample_interval_share = 30
38
+ use_wandb = False
39
+ save_model = True
40
+ use_decay = True
41
+ optimizer_type = "adam8bit"
42
+ dtype = torch.float32
43
+
44
+ model_resolution = 512
45
+ high_resolution = 1024
46
+ limit = 0
47
+ save_barrier = 1.3
48
+ warmup_percent = 0.005
49
+ beta2 = 0.997
50
+ eps = 1e-8
51
+ clip_grad_norm = 2.0 # Увеличили с 1.0 для лучшей прорисовки деталей
52
+ mixed_precision = "no"
53
+ gradient_accumulation_steps = 1
54
+ generated_folder = "samples"
55
+ save_as = "vae3"
56
+ num_workers = 0
57
+ device = None
58
+ torch.backends.cuda.matmul.allow_tf32 = True
59
+ torch.backends.cudnn.allow_tf32 = True
60
+ # Включение Flash Attention 2/SDPA #MAX_JOBS=4 pip install flash-attn --no-build-isolation
61
+ torch.backends.cuda.enable_flash_sdp(True)
62
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
63
+ torch.backends.cuda.enable_math_sdp(False)
64
+
65
+ # --- Режимы обучения ---
66
+ # QWEN: учим только декодер
67
+ train_decoder_only = True
68
+ train_up_only = False
69
+ full_training = False # если True — учим весь VAE и добавляем KL (ниже)
70
+ kl_ratio = 0.0
71
+
72
+ # Доли лоссов
73
+ loss_ratios = {
74
+ "lpips": 0.60,
75
+ "fdl" : 0.15,
76
+ "mse": 0.00,
77
+ "mae": 0.20,
78
+ "dssim": 0.05, # <-- Заменён edge на DSSIM
79
+ "kl": 0.00,
80
+ }
81
+ median_coeff_steps = 250
82
+
83
+ resize_long_side = 1280 # ресайз длинной стороны исходных картинок
84
+
85
+ # QWEN: конфиг загрузки модели
86
+ vae_kind = "kl" # "qwen" или "kl" (обычный)
87
+
88
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
89
+
90
+ accelerator = Accelerator(
91
+ mixed_precision=mixed_precision,
92
+ gradient_accumulation_steps=gradient_accumulation_steps
93
+ )
94
+ device = accelerator.device
95
+
96
+ # reproducibility
97
+ seed = int(datetime.now().strftime("%Y%m%d")) + 42
98
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
99
+ torch.backends.cudnn.benchmark = False
100
+
101
+ # --------------------------- WandB ---------------------------
102
+ if use_wandb and accelerator.is_main_process:
103
+ wandb.init(project=project, config={
104
+ "batch_size": batch_size,
105
+ "base_learning_rate": base_learning_rate,
106
+ "num_epochs": num_epochs,
107
+ "optimizer_type": optimizer_type,
108
+ "model_resolution": model_resolution,
109
+ "high_resolution": high_resolution,
110
+ "gradient_accumulation_steps": gradient_accumulation_steps,
111
+ "train_decoder_only": train_decoder_only,
112
+ "full_training": full_training,
113
+ "kl_ratio": kl_ratio,
114
+ "vae_kind": vae_kind,
115
+ })
116
+
117
+ # --------------------------- VAE ---------------------------
118
+ def get_core_model(model):
119
+ m = model
120
+ # если модель уже обёрнута torch.compile
121
+ if hasattr(m, "_orig_mod"):
122
+ m = m._orig_mod
123
+ return m
124
+
125
+ def is_video_vae(model) -> bool:
126
+ # WAN/Qwen — это видео-VAEs
127
+ if vae_kind in ("wan", "qwen"):
128
+ return True
129
+ # fallback по структуре (если понадобится)
130
+ try:
131
+ core = get_core_model(model)
132
+ enc = getattr(core, "encoder", None)
133
+ conv_in = getattr(enc, "conv_in", None)
134
+ w = getattr(conv_in, "weight", None)
135
+ if isinstance(w, torch.nn.Parameter):
136
+ return w.ndim == 5
137
+ except Exception:
138
+ pass
139
+ return False
140
+
141
+ # загрузка
142
+ if vae_kind == "qwen":
143
+ vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
144
+ else:
145
+ if vae_kind == "wan":
146
+ vae = AutoencoderKLWan.from_pretrained(project)
147
+ else:
148
+ # старое поведение (пример)
149
+ if model_resolution==high_resolution:
150
+ vae = AutoencoderKL.from_pretrained(project)
151
+ else:
152
+ vae = AsymmetricAutoencoderKL.from_pretrained(project)
153
+
154
+ vae = vae.to(dtype)
155
+
156
+ # torch.compile (опционально)
157
+ if hasattr(torch, "compile"):
158
+ try:
159
+ vae = torch.compile(vae)
160
+ except Exception as e:
161
+ print(f"[WARN] torch.compile failed: {e}")
162
+
163
+ # --------------------------- Freeze/Unfreeze ---------------------------
164
+ core = get_core_model(vae)
165
+
166
+ for p in core.parameters():
167
+ p.requires_grad = False
168
+
169
+ unfrozen_param_names = []
170
+
171
+ if full_training and not train_decoder_only:
172
+ for name, p in core.named_parameters():
173
+ p.requires_grad = True
174
+ unfrozen_param_names.append(name)
175
+ loss_ratios["kl"] = float(kl_ratio)
176
+ trainable_module = core
177
+ else:
178
+ # учим только 0-й блок декодера + post_quant_conv
179
+ if hasattr(core, "decoder"):
180
+ if train_up_only:#hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
181
+ # --- только 0-й up_block ---
182
+ for name, p in core.decoder.up_blocks[0].named_parameters():
183
+ p.requires_grad = True
184
+ unfrozen_param_names.append(f"{name}")
185
+ else:
186
+ print("Decoder — fallback to full decoder")
187
+ for name, p in core.decoder.named_parameters():
188
+ p.requires_grad = True
189
+ unfrozen_param_names.append(f"decoder.{name}")
190
+ if hasattr(core, "post_quant_conv"):
191
+ for name, p in core.post_quant_conv.named_parameters():
192
+ p.requires_grad = True
193
+ unfrozen_param_names.append(f"post_quant_conv.{name}")
194
+ trainable_module = core.decoder if hasattr(core, "decoder") else core
195
+
196
+
197
+ print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
198
+ for nm in unfrozen_param_names[:10]:
199
+ print(" ", nm)
200
+
201
+ # --------------------------- Датасет ---------------------------
202
+ from torch.utils.data import Dataset
203
+ from PIL import Image, UnidentifiedImageError
204
+ import random
205
+ import torchvision.transforms as transforms
206
+ import os
207
+
208
+ class PngFolderDataset(Dataset):
209
+ def __init__(self, root_dir, resolution=1024, min_exts=('.png',), limit=0):
210
+ self.resolution = resolution
211
+ self.paths = []
212
+
213
+ for root, _, files in os.walk(root_dir):
214
+ for f in files:
215
+ if f.lower().endswith(tuple(ext.lower() for ext in min_exts)):
216
+ self.paths.append(os.path.join(root, f))
217
+
218
+ if limit:
219
+ self.paths = self.paths[:limit]
220
+
221
+ # фильтруем недопустимые картинки
222
+ valid = []
223
+ for p in self.paths:
224
+ try:
225
+ with Image.open(p) as img:
226
+ img.verify() # только метаданные
227
+ w, h = img.size
228
+ if w < resolution or h < resolution:
229
+ continue
230
+ valid.append(p)
231
+ except (OSError, UnidentifiedImageError):
232
+ continue
233
+
234
+ self.paths = valid
235
+ if not self.paths:
236
+ raise RuntimeError("No valid images found")
237
+
238
+ random.shuffle(self.paths)
239
+ self.transform = transforms.ToTensor() # конвертирует сразу [0,1] float32
240
+
241
+ def __len__(self):
242
+ return len(self.paths)
243
+
244
+ def __getitem__(self, idx):
245
+ p = self.paths[idx % len(self.paths)]
246
+ with Image.open(p) as img:
247
+ img = img.convert("RGB")
248
+ return img
249
+
250
+ def random_crop(img, sz):
251
+ w, h = img.size
252
+ x = random.randint(0, max(1, img.width - sz))
253
+ y = random.randint(0, max(1, img.height - sz))
254
+ return img.crop((x, y, x + sz, y + sz))
255
+
256
+ tfm = transforms.Compose([
257
+ transforms.ToTensor(),
258
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
259
+ ])
260
+
261
+ def collate_fn(batch):
262
+ imgs = []
263
+ for img in batch:
264
+ img = random_crop(img, high_resolution)
265
+ imgs.append(tfm(img))
266
+ return torch.stack(imgs)
267
+
268
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
269
+ print("len(dataset)",len(dataset))
270
+ if len(dataset) < batch_size:
271
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
272
+
273
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, drop_last=True)
274
+
275
+ # --------------------------- Оптимизатор ---------------------------
276
+ def get_param_groups(module, weight_decay=0.001):
277
+ no_decay_tokens = ("bias", "norm", "rms", "layernorm")
278
+ decay_params, no_decay_params = [], []
279
+ for n, p in module.named_parameters():
280
+ if not p.requires_grad:
281
+ continue
282
+ n_l = n.lower()
283
+ if any(t in n_l for t in no_decay_tokens):
284
+ no_decay_params.append(p)
285
+ else:
286
+ decay_params.append(p)
287
+ return [
288
+ {"params": decay_params, "weight_decay": weight_decay},
289
+ {"params": no_decay_params, "weight_decay": 0.0},
290
+ ]
291
+
292
+ def create_optimizer(name, param_groups):
293
+ if name == "adam8bit":
294
+ return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
295
+ raise ValueError(name)
296
+
297
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
298
+ optimizer = create_optimizer(optimizer_type, param_groups)
299
+
300
+ # --------------------------- LR schedule ---------------------------
301
+ batches_per_epoch = len(dataloader)
302
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
303
+ total_steps = steps_per_epoch * num_epochs
304
+
305
+ def lr_lambda(step):
306
+ if not use_decay:
307
+ return 1.0
308
+ x = float(step) / float(max(1, total_steps))
309
+ warmup = float(warmup_percent)
310
+ min_ratio = float(min_learning_rate) / float(base_learning_rate)
311
+ if x < warmup:
312
+ return min_ratio + (1.0 - min_ratio) * (x / warmup)
313
+ decay_ratio = (x - warmup) / (1.0 - warmup)
314
+ return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
315
+
316
+ scheduler = LambdaLR(optimizer, lr_lambda)
317
+
318
+ # Подготовка
319
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
320
+ trainable_params = [p for p in vae.parameters() if p.requires_grad]
321
+
322
+ # fdl
323
+ fdl_loss = FDL_loss()
324
+ fdl_loss = fdl_loss.to(accelerator.device)
325
+
326
+ # --------------------------- LPIPS и DSSIM ---------------------------
327
+ _lpips_net = None
328
+ def _get_lpips():
329
+ global _lpips_net
330
+ if _lpips_net is None:
331
+ _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
332
+ return _lpips_net
333
+
334
+ # --- DSSIM Implementation (замена Sobel) ---
335
+ def _gaussian_kernel(window_size, sigma, device, dtype):
336
+ coords = torch.arange(window_size, dtype=dtype, device=device) - (window_size - 1) / 2
337
+ k = torch.exp(-coords**2 / (2 * sigma**2))
338
+ return k / k.sum()
339
+
340
+ def _ssim(x, y, window_size=11, sigma=1.5):
341
+ """Вычисляет SSIM для тензоров в диапазоне [-1, 1]"""
342
+ C = x.shape[1]
343
+ k = _gaussian_kernel(window_size, sigma, x.device, x.dtype)
344
+ win = (k.view(1, 1, -1, 1) * k.view(1, 1, 1, -1)).expand(C, 1, window_size, window_size).contiguous()
345
+
346
+ mu1 = F.conv2d(x, win, padding=window_size//2, groups=C)
347
+ mu2 = F.conv2d(y, win, padding=window_size//2, groups=C)
348
+ mu1_sq = mu1.pow(2)
349
+ mu2_sq = mu2.pow(2)
350
+ mu1_mu2 = mu1 * mu2
351
+
352
+ sigma1_sq = F.conv2d(x*x, win, padding=window_size//2, groups=C) - mu1_sq
353
+ sigma2_sq = F.conv2d(y*y, win, padding=window_size//2, groups=C) - mu2_sq
354
+ sigma12 = F.conv2d(x*y, win, padding=window_size//2, groups=C) - mu1_mu2
355
+
356
+ L = 2.0 # диапазон для [-1, 1] -> 2
357
+ C1 = (0.01 * L)**2
358
+ C2 = (0.03 * L)**2
359
+
360
+ num = (2*mu1_mu2 + C1) * (2*sigma12 + C2)
361
+ den = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
362
+ ssim_map = num / den
363
+ return ssim_map.mean()
364
+
365
+ def dssim_loss(x, y):
366
+ """DSSIM = 1 - SSIM. Отлично сохраняет локальную структуру штрихов."""
367
+ return 1.0 - _ssim(x, y)
368
+
369
+ class MedianLossNormalizer:
370
+ def __init__(self, desired_ratios: dict, window_steps: int):
371
+ s = sum(desired_ratios.values())
372
+ self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
373
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
374
+ self.window = window_steps
375
+
376
+ def update_and_total(self, abs_losses: dict):
377
+ for k, v in abs_losses.items():
378
+ if k in self.buffers:
379
+ self.buffers[k].append(float(v.detach().abs().cpu()))
380
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
381
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
382
+ total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
383
+ return total, coeffs, meds
384
+
385
+ if full_training and not train_decoder_only:
386
+ loss_ratios["kl"] = float(kl_ratio)
387
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
388
+
389
+ # --------------------------- Сэмплы ---------------------------
390
+ @torch.no_grad()
391
+ def get_fixed_samples(n=3):
392
+ idx = random.sample(range(len(dataset)), min(n, len(dataset)))
393
+ pil_imgs = [dataset[i] for i in idx]
394
+ tensors = []
395
+ for img in pil_imgs:
396
+ img = random_crop(img, high_resolution)
397
+ tensors.append(tfm(img))
398
+ return torch.stack(tensors).to(accelerator.device, dtype)
399
+
400
+ fixed_samples = get_fixed_samples()
401
+
402
+ @torch.no_grad()
403
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
404
+ arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
405
+ return Image.fromarray(arr)
406
+
407
+
408
+ @torch.no_grad()
409
+ def generate_and_save_samples(step=None):
410
+ try:
411
+ #temp_vae = accelerator.unwrap_model(vae).eval()
412
+ if hasattr(vae, "module"):
413
+ # Если это DDP или DistributedDataParallel
414
+ unwrapped_vae = vae.module
415
+ else:
416
+ unwrapped_vae = vae
417
+
418
+ # Если испол��зовался torch.compile, достаем оригинал
419
+ if hasattr(unwrapped_vae, "_orig_mod"):
420
+ temp_vae = unwrapped_vae._orig_mod
421
+ else:
422
+ temp_vae = unwrapped_vae
423
+
424
+ temp_vae = temp_vae.eval()
425
+ lpips_net = _get_lpips()
426
+ with torch.no_grad():
427
+ orig_high = fixed_samples
428
+ orig_low = F.interpolate(
429
+ orig_high,
430
+ size=(model_resolution, model_resolution),
431
+ mode="bilinear",
432
+ align_corners=False
433
+ )
434
+ model_dtype = next(temp_vae.parameters()).dtype
435
+ orig_low = orig_low.to(dtype=model_dtype)
436
+
437
+ # Encode/decode с учётом видео-режима
438
+ if is_video_vae(temp_vae):
439
+ x_in = orig_low.unsqueeze(2) # [B,3,1,H,W]
440
+ enc = temp_vae.encode(x_in)
441
+ latents_mean = enc.latent_dist.mean
442
+ dec = temp_vae.decode(latents_mean).sample # [B,3,1,H,W]
443
+ rec = dec.squeeze(2) # [B,3,H,W]
444
+ else:
445
+ enc = temp_vae.encode(orig_low)
446
+ latents_mean = enc.latent_dist.mean
447
+ rec = temp_vae.decode(latents_mean).sample
448
+
449
+ # Подгон размеров, если надо
450
+ #if rec.shape[-2:] != orig_high.shape[-2:]:
451
+ # rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
452
+
453
+ # Сохраняем все real/decoded
454
+ for i in range(rec.shape[0]):
455
+ real_img = _to_pil_uint8(orig_high[i])
456
+ dec_img = _to_pil_uint8(rec[i])
457
+ real_img.save(f"{generated_folder}/sample_real_{i}.png")
458
+ dec_img.save(f"{generated_folder}/sample_decoded_{i}.png")
459
+
460
+ # LPIPS
461
+ lpips_scores = []
462
+ for i in range(rec.shape[0]):
463
+ orig_full = orig_high[i:i+1].to(torch.float32)
464
+ rec_full = rec[i:i+1].to(torch.float32)
465
+ #if rec_full.shape[-2:] != orig_full.shape[-2:]:
466
+ # rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
467
+ lpips_val = lpips_net(orig_full, rec_full).item()
468
+ lpips_scores.append(lpips_val)
469
+ avg_lpips = float(np.mean(lpips_scores))
470
+
471
+ # W&B логирование
472
+ if use_wandb and accelerator.is_main_process:
473
+ log_data = {"lpips_mean": avg_lpips}
474
+ for i in range(rec.shape[0]):
475
+ log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.png", caption=f"real_{i}")
476
+ log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.png", caption=f"decoded_{i}")
477
+ wandb.log(log_data, step=step)
478
+
479
+ finally:
480
+ gc.collect()
481
+ torch.cuda.empty_cache()
482
+
483
+
484
+ if accelerator.is_main_process and save_model:
485
+ print("Генерация сэмплов до старта обучения...")
486
+ generate_and_save_samples(0)
487
+
488
+ accelerator.wait_for_everyone()
489
+
490
+ # --------------------------- Тренировка ---------------------------
491
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
492
+ global_step = 0
493
+ min_loss = float("inf")
494
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
495
+
496
+ for epoch in range(num_epochs):
497
+ vae.train()
498
+ batch_losses, batch_grads = [], []
499
+ track_losses = {k: [] for k in loss_ratios.keys()}
500
+
501
+ for imgs in dataloader:
502
+ with accelerator.accumulate(vae):
503
+ imgs = imgs.to(accelerator.device)
504
+
505
+ if high_resolution != model_resolution:
506
+ imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution),mode="area") # mode="bilinear", align_corners=False)
507
+ else:
508
+ imgs_low = imgs
509
+
510
+ model_dtype = next(vae.parameters()).dtype
511
+ imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
512
+
513
+ # Вместо: current_vae = accelerator.unwrap_model(vae)
514
+ unwrapped = vae.module if hasattr(vae, "module") else vae
515
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
516
+
517
+
518
+ # QWEN: encode/decode с T=1
519
+ if is_video_vae(current_vae):
520
+ x_in = imgs_low_model.unsqueeze(2) # [B,3,1,H,W]
521
+ enc = current_vae.encode(x_in)
522
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
523
+ dec = current_vae.decode(latents).sample # [B,3,1,H,W]
524
+ rec = dec.squeeze(2) # [B,3,H,W]
525
+ else:
526
+ enc = current_vae.encode(imgs_low_model)
527
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
528
+ rec = current_vae.decode(latents).sample
529
+
530
+ rec_f32 = rec.to(torch.float32)
531
+ imgs_f32 = imgs.to(torch.float32)
532
+
533
+ # --- Расчёт лоссов ---
534
+ abs_losses = {
535
+ "mae": F.l1_loss(rec_f32, imgs_f32),
536
+ "mse": torch.tensor(0.0, device=accelerator.device, dtype=torch.float32), # Отключён, экономим вычисления
537
+ "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
538
+ "fdl": fdl_loss(rec_f32, imgs_f32),
539
+ "dssim": dssim_loss(rec_f32, imgs_f32), # <-- DSSIM вместо Edge
540
+ }
541
+
542
+ if full_training and not train_decoder_only:
543
+ mean = enc.latent_dist.mean
544
+ logvar = enc.latent_dist.logvar
545
+ kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
546
+ abs_losses["kl"] = kl
547
+ else:
548
+ abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
549
+
550
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
551
+
552
+ if torch.isnan(total_loss) or torch.isinf(total_loss):
553
+ raise RuntimeError("NaN/Inf loss")
554
+
555
+ accelerator.backward(total_loss)
556
+
557
+ grad_norm = torch.tensor(0.0, device=accelerator.device)
558
+ if accelerator.sync_gradients:
559
+ grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
560
+ optimizer.step()
561
+ scheduler.step()
562
+ optimizer.zero_grad(set_to_none=True)
563
+ global_step += 1
564
+ progress.update(1)
565
+
566
+ if accelerator.is_main_process:
567
+ try:
568
+ current_lr = optimizer.param_groups[0]["lr"]
569
+ except Exception:
570
+ current_lr = scheduler.get_last_lr()[0]
571
+
572
+ batch_losses.append(total_loss.detach().item())
573
+ batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
574
+ for k, v in abs_losses.items():
575
+ track_losses[k].append(float(v.detach().item()))
576
+
577
+ if use_wandb and accelerator.sync_gradients:
578
+ log_dict = {
579
+ "total_loss": float(total_loss.detach().item()),
580
+ "learning_rate": current_lr,
581
+ "epoch": epoch,
582
+ "grad_norm": batch_grads[-1],
583
+ }
584
+ for k, v in abs_losses.items():
585
+ log_dict[f"loss_{k}"] = float(v.detach().item())
586
+ for k in coeffs:
587
+ log_dict[f"coeff_{k}"] = float(coeffs[k])
588
+ log_dict[f"median_{k}"] = float(meds[k])
589
+ wandb.log(log_dict, step=global_step)
590
+
591
+ if global_step > 0 and global_step % sample_interval == 0:
592
+ if accelerator.is_main_process:
593
+ generate_and_save_samples(global_step)
594
+ accelerator.wait_for_everyone()
595
+
596
+ n_micro = sample_interval * gradient_accumulation_steps
597
+ avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
598
+ avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
599
+
600
+ if accelerator.is_main_process:
601
+ print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
602
+ if save_model and avg_loss < min_loss * save_barrier:
603
+ min_loss = avg_loss
604
+ unwrapped = vae.module if hasattr(vae, "module") else vae
605
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
606
+ current_vae.save_pretrained(save_as)
607
+ if use_wandb:
608
+ wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
609
+
610
+ if accelerator.is_main_process:
611
+ epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
612
+ print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
613
+ if use_wandb:
614
+ wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
615
+
616
+ # --------------------------- Финальное сохранение ---------------------------
617
+ if accelerator.is_main_process:
618
+ print("Training finished – saving final model")
619
+ if save_model:
620
+ unwrapped = vae.module if hasattr(vae, "module") else vae
621
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
622
+ current_vae.save_pretrained(save_as)
623
+
624
+ accelerator.free_memory()
625
+ if torch.distributed.is_initialized():
626
+ torch.distributed.destroy_process_group()
627
+ print("Готово!")