babkasotona commited on
Commit
4f8199d
·
verified ·
1 Parent(s): efd88e3

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. train_sdxs_vae.py +45 -27
train_sdxs_vae.py CHANGED
@@ -26,7 +26,6 @@ import wandb
26
  import lpips # pip install lpips
27
  from FDL_pytorch import FDL_loss # pip install fdl-pytorch
28
  from collections import deque
29
- import torch.nn.functional as F
30
 
31
  # --------------------------- Параметры ---------------------------
32
  ds_path = "/workspace/d23"
@@ -49,7 +48,7 @@ save_barrier = 1.3
49
  warmup_percent = 0.005
50
  beta2 = 0.997
51
  eps = 1e-8
52
- clip_grad_norm = 1.0
53
  mixed_precision = "no"
54
  gradient_accumulation_steps = 1
55
  generated_folder = "samples"
@@ -72,10 +71,11 @@ kl_ratio = 0.0
72
 
73
  # Доли лоссов
74
  loss_ratios = {
75
- "lpips": 0.70,#0.50,
76
- "fdl" : 0.10,#0.25,
77
- "mse": 0.10,
78
- "mae": 0.10,
 
79
  "kl": 0.00,
80
  }
81
  median_coeff_steps = 250
@@ -273,21 +273,6 @@ if len(dataset) < batch_size:
273
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, drop_last=True)
274
 
275
  # --------------------------- Оптимизатор ---------------------------
276
- def get_param_groups(module, weight_decay=0.001):
277
- no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
278
- decay_params, no_decay_params = [], []
279
- for n, p in vae.named_parameters(): # глобально по vae, с фильтром requires_grad
280
- if not p.requires_grad:
281
- continue
282
- if any(nd in n for nd in no_decay):
283
- no_decay_params.append(p)
284
- else:
285
- decay_params.append(p)
286
- return [
287
- {"params": decay_params, "weight_decay": weight_decay},
288
- {"params": no_decay_params, "weight_decay": 0.0},
289
- ]
290
-
291
  def get_param_groups(module, weight_decay=0.001):
292
  no_decay_tokens = ("bias", "norm", "rms", "layernorm")
293
  decay_params, no_decay_params = [], []
@@ -338,7 +323,7 @@ trainable_params = [p for p in vae.parameters() if p.requires_grad]
338
  fdl_loss = FDL_loss()
339
  fdl_loss = fdl_loss.to(accelerator.device)
340
 
341
- # --------------------------- LPIPS и вспомогательные ---------------------------
342
  _lpips_net = None
343
  def _get_lpips():
344
  global _lpips_net
@@ -346,7 +331,41 @@ def _get_lpips():
346
  _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
347
  return _lpips_net
348
 
349
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  class MedianLossNormalizer:
351
  def __init__(self, desired_ratios: dict, window_steps: int):
352
  s = sum(desired_ratios.values())
@@ -508,17 +527,16 @@ for epoch in range(num_epochs):
508
  latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
509
  rec = current_vae.decode(latents).sample
510
 
511
- #if rec.shape[-2:] != imgs.shape[-2:]:
512
- # rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
513
-
514
  rec_f32 = rec.to(torch.float32)
515
  imgs_f32 = imgs.to(torch.float32)
516
 
 
517
  abs_losses = {
518
  "mae": F.l1_loss(rec_f32, imgs_f32),
519
- "mse": F.mse_loss(rec_f32, imgs_f32),
520
  "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
521
  "fdl": fdl_loss(rec_f32, imgs_f32),
 
522
  }
523
 
524
  if full_training and not train_decoder_only:
 
26
  import lpips # pip install lpips
27
  from FDL_pytorch import FDL_loss # pip install fdl-pytorch
28
  from collections import deque
 
29
 
30
  # --------------------------- Параметры ---------------------------
31
  ds_path = "/workspace/d23"
 
48
  warmup_percent = 0.005
49
  beta2 = 0.997
50
  eps = 1e-8
51
+ clip_grad_norm = 2.0 # Увеличили с 1.0 для лучшей прорисовки деталей
52
  mixed_precision = "no"
53
  gradient_accumulation_steps = 1
54
  generated_folder = "samples"
 
71
 
72
  # Доли лоссов
73
  loss_ratios = {
74
+ "lpips": 0.60,
75
+ "fdl" : 0.15,
76
+ "mse": 0.00,
77
+ "mae": 0.20,
78
+ "dssim": 0.05, # <-- Заменён edge на DSSIM
79
  "kl": 0.00,
80
  }
81
  median_coeff_steps = 250
 
273
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, drop_last=True)
274
 
275
  # --------------------------- Оптимизатор ---------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  def get_param_groups(module, weight_decay=0.001):
277
  no_decay_tokens = ("bias", "norm", "rms", "layernorm")
278
  decay_params, no_decay_params = [], []
 
323
  fdl_loss = FDL_loss()
324
  fdl_loss = fdl_loss.to(accelerator.device)
325
 
326
+ # --------------------------- LPIPS и DSSIM ---------------------------
327
  _lpips_net = None
328
  def _get_lpips():
329
  global _lpips_net
 
331
  _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
332
  return _lpips_net
333
 
334
+ # --- DSSIM Implementation (замена Sobel) ---
335
+ def _gaussian_kernel(window_size, sigma, device, dtype):
336
+ coords = torch.arange(window_size, dtype=dtype, device=device) - (window_size - 1) / 2
337
+ k = torch.exp(-coords**2 / (2 * sigma**2))
338
+ return k / k.sum()
339
+
340
+ def _ssim(x, y, window_size=11, sigma=1.5):
341
+ """Вычисляет SSIM для тензоров в диапазоне [-1, 1]"""
342
+ C = x.shape[1]
343
+ k = _gaussian_kernel(window_size, sigma, x.device, x.dtype)
344
+ win = (k.view(1, 1, -1, 1) * k.view(1, 1, 1, -1)).expand(C, 1, window_size, window_size).contiguous()
345
+
346
+ mu1 = F.conv2d(x, win, padding=window_size//2, groups=C)
347
+ mu2 = F.conv2d(y, win, padding=window_size//2, groups=C)
348
+ mu1_sq = mu1.pow(2)
349
+ mu2_sq = mu2.pow(2)
350
+ mu1_mu2 = mu1 * mu2
351
+
352
+ sigma1_sq = F.conv2d(x*x, win, padding=window_size//2, groups=C) - mu1_sq
353
+ sigma2_sq = F.conv2d(y*y, win, padding=window_size//2, groups=C) - mu2_sq
354
+ sigma12 = F.conv2d(x*y, win, padding=window_size//2, groups=C) - mu1_mu2
355
+
356
+ L = 2.0 # диапазон для [-1, 1] -> 2
357
+ C1 = (0.01 * L)**2
358
+ C2 = (0.03 * L)**2
359
+
360
+ num = (2*mu1_mu2 + C1) * (2*sigma12 + C2)
361
+ den = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
362
+ ssim_map = num / den
363
+ return ssim_map.mean()
364
+
365
+ def dssim_loss(x, y):
366
+ """DSSIM = 1 - SSIM. Отлично сохраняет локальную структуру штрихов."""
367
+ return 1.0 - _ssim(x, y)
368
+
369
  class MedianLossNormalizer:
370
  def __init__(self, desired_ratios: dict, window_steps: int):
371
  s = sum(desired_ratios.values())
 
527
  latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
528
  rec = current_vae.decode(latents).sample
529
 
 
 
 
530
  rec_f32 = rec.to(torch.float32)
531
  imgs_f32 = imgs.to(torch.float32)
532
 
533
+ # --- Расчёт лоссов ---
534
  abs_losses = {
535
  "mae": F.l1_loss(rec_f32, imgs_f32),
536
+ "mse": torch.tensor(0.0, device=accelerator.device, dtype=torch.float32), # Отключён, экономим вычисления
537
  "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
538
  "fdl": fdl_loss(rec_f32, imgs_f32),
539
+ "dssim": dssim_loss(rec_f32, imgs_f32), # <-- DSSIM вместо Edge
540
  }
541
 
542
  if full_training and not train_decoder_only: