Instructions to use babkasotona/vae2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use babkasotona/vae2 with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("babkasotona/vae2", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files
- train_sdxs_vae.py +45 -27
train_sdxs_vae.py
CHANGED
|
@@ -26,7 +26,6 @@ import wandb
|
|
| 26 |
import lpips # pip install lpips
|
| 27 |
from FDL_pytorch import FDL_loss # pip install fdl-pytorch
|
| 28 |
from collections import deque
|
| 29 |
-
import torch.nn.functional as F
|
| 30 |
|
| 31 |
# --------------------------- Параметры ---------------------------
|
| 32 |
ds_path = "/workspace/d23"
|
|
@@ -49,7 +48,7 @@ save_barrier = 1.3
|
|
| 49 |
warmup_percent = 0.005
|
| 50 |
beta2 = 0.997
|
| 51 |
eps = 1e-8
|
| 52 |
-
clip_grad_norm = 1.0
|
| 53 |
mixed_precision = "no"
|
| 54 |
gradient_accumulation_steps = 1
|
| 55 |
generated_folder = "samples"
|
|
@@ -72,10 +71,11 @@ kl_ratio = 0.0
|
|
| 72 |
|
| 73 |
# Доли лоссов
|
| 74 |
loss_ratios = {
|
| 75 |
-
"lpips": 0.
|
| 76 |
-
"fdl" : 0.
|
| 77 |
-
"mse": 0.
|
| 78 |
-
"mae": 0.
|
|
|
|
| 79 |
"kl": 0.00,
|
| 80 |
}
|
| 81 |
median_coeff_steps = 250
|
|
@@ -273,21 +273,6 @@ if len(dataset) < batch_size:
|
|
| 273 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, drop_last=True)
|
| 274 |
|
| 275 |
# --------------------------- Оптимизатор ---------------------------
|
| 276 |
-
def get_param_groups(module, weight_decay=0.001):
|
| 277 |
-
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
|
| 278 |
-
decay_params, no_decay_params = [], []
|
| 279 |
-
for n, p in vae.named_parameters(): # глобально по vae, с фильтром requires_grad
|
| 280 |
-
if not p.requires_grad:
|
| 281 |
-
continue
|
| 282 |
-
if any(nd in n for nd in no_decay):
|
| 283 |
-
no_decay_params.append(p)
|
| 284 |
-
else:
|
| 285 |
-
decay_params.append(p)
|
| 286 |
-
return [
|
| 287 |
-
{"params": decay_params, "weight_decay": weight_decay},
|
| 288 |
-
{"params": no_decay_params, "weight_decay": 0.0},
|
| 289 |
-
]
|
| 290 |
-
|
| 291 |
def get_param_groups(module, weight_decay=0.001):
|
| 292 |
no_decay_tokens = ("bias", "norm", "rms", "layernorm")
|
| 293 |
decay_params, no_decay_params = [], []
|
|
@@ -338,7 +323,7 @@ trainable_params = [p for p in vae.parameters() if p.requires_grad]
|
|
| 338 |
fdl_loss = FDL_loss()
|
| 339 |
fdl_loss = fdl_loss.to(accelerator.device)
|
| 340 |
|
| 341 |
-
# --------------------------- LPIPS и
|
| 342 |
_lpips_net = None
|
| 343 |
def _get_lpips():
|
| 344 |
global _lpips_net
|
|
@@ -346,7 +331,41 @@ def _get_lpips():
|
|
| 346 |
_lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
|
| 347 |
return _lpips_net
|
| 348 |
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
class MedianLossNormalizer:
|
| 351 |
def __init__(self, desired_ratios: dict, window_steps: int):
|
| 352 |
s = sum(desired_ratios.values())
|
|
@@ -508,17 +527,16 @@ for epoch in range(num_epochs):
|
|
| 508 |
latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
|
| 509 |
rec = current_vae.decode(latents).sample
|
| 510 |
|
| 511 |
-
#if rec.shape[-2:] != imgs.shape[-2:]:
|
| 512 |
-
# rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
|
| 513 |
-
|
| 514 |
rec_f32 = rec.to(torch.float32)
|
| 515 |
imgs_f32 = imgs.to(torch.float32)
|
| 516 |
|
|
|
|
| 517 |
abs_losses = {
|
| 518 |
"mae": F.l1_loss(rec_f32, imgs_f32),
|
| 519 |
-
"mse":
|
| 520 |
"lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
|
| 521 |
"fdl": fdl_loss(rec_f32, imgs_f32),
|
|
|
|
| 522 |
}
|
| 523 |
|
| 524 |
if full_training and not train_decoder_only:
|
|
|
|
| 26 |
import lpips # pip install lpips
|
| 27 |
from FDL_pytorch import FDL_loss # pip install fdl-pytorch
|
| 28 |
from collections import deque
|
|
|
|
| 29 |
|
| 30 |
# --------------------------- Параметры ---------------------------
|
| 31 |
ds_path = "/workspace/d23"
|
|
|
|
| 48 |
warmup_percent = 0.005
|
| 49 |
beta2 = 0.997
|
| 50 |
eps = 1e-8
|
| 51 |
+
clip_grad_norm = 2.0 # Увеличили с 1.0 для лучшей прорисовки деталей
|
| 52 |
mixed_precision = "no"
|
| 53 |
gradient_accumulation_steps = 1
|
| 54 |
generated_folder = "samples"
|
|
|
|
| 71 |
|
| 72 |
# Доли лоссов
|
| 73 |
loss_ratios = {
|
| 74 |
+
"lpips": 0.60,
|
| 75 |
+
"fdl" : 0.15,
|
| 76 |
+
"mse": 0.00,
|
| 77 |
+
"mae": 0.20,
|
| 78 |
+
"dssim": 0.05, # <-- Заменён edge на DSSIM
|
| 79 |
"kl": 0.00,
|
| 80 |
}
|
| 81 |
median_coeff_steps = 250
|
|
|
|
| 273 |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, drop_last=True)
|
| 274 |
|
| 275 |
# --------------------------- Оптимизатор ---------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
def get_param_groups(module, weight_decay=0.001):
|
| 277 |
no_decay_tokens = ("bias", "norm", "rms", "layernorm")
|
| 278 |
decay_params, no_decay_params = [], []
|
|
|
|
| 323 |
fdl_loss = FDL_loss()
|
| 324 |
fdl_loss = fdl_loss.to(accelerator.device)
|
| 325 |
|
| 326 |
+
# --------------------------- LPIPS и DSSIM ---------------------------
|
| 327 |
_lpips_net = None
|
| 328 |
def _get_lpips():
|
| 329 |
global _lpips_net
|
|
|
|
| 331 |
_lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
|
| 332 |
return _lpips_net
|
| 333 |
|
| 334 |
+
# --- DSSIM Implementation (замена Sobel) ---
|
| 335 |
+
def _gaussian_kernel(window_size, sigma, device, dtype):
    """Return a normalized 1-D Gaussian window.

    Args:
        window_size: Number of taps in the window.
        sigma: Standard deviation of the Gaussian, in taps.
        device: Device on which the window tensor is created.
        dtype: Floating-point dtype of the window tensor.

    Returns:
        A tensor of shape ``(window_size,)`` whose entries sum to 1.
    """
    # Tap positions centred on zero, e.g. [-5, ..., 5] for window_size=11.
    offsets = torch.arange(window_size, dtype=dtype, device=device) - (window_size - 1) / 2
    weights = torch.exp(-offsets**2 / (2 * sigma**2))
    return weights / weights.sum()


def _ssim(x, y, window_size=11, sigma=1.5, data_range=2.0):
    """Compute the mean SSIM between two image batches.

    Local means, variances and covariance are taken with a 2-D Gaussian
    window (outer product of the 1-D kernel) applied independently to each
    channel, using zero padding of ``window_size // 2``.

    Args:
        x: First batch; expected as (N, C, H, W) -- the channel count is
            read from ``x.shape[1]``.
        y: Second batch, same layout as ``x``.
        window_size: Side length of the Gaussian window.
        sigma: Standard deviation of the Gaussian window.
        data_range: Dynamic range of the inputs, used to scale the
            stabilising constants. The default of 2.0 corresponds to
            inputs in [-1, 1].

    Returns:
        A scalar tensor: the SSIM map averaged over all of its elements.
    """
    # Imported locally so this helper does not rely on a module-level
    # ``F`` alias being present.
    import torch.nn.functional as F

    channels = x.shape[1]
    pad = window_size // 2

    kernel_1d = _gaussian_kernel(window_size, sigma, x.device, x.dtype)
    window = (
        (kernel_1d.view(1, 1, -1, 1) * kernel_1d.view(1, 1, 1, -1))
        .expand(channels, 1, window_size, window_size)
        .contiguous()
    )

    def _local_mean(t):
        # groups=channels -> one filter per channel (no cross-channel mixing).
        return F.conv2d(t, window, padding=pad, groups=channels)

    mu1 = _local_mean(x)
    mu2 = _local_mean(y)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y], window-weighted.
    sigma1_sq = _local_mean(x * x) - mu1_sq
    sigma2_sq = _local_mean(y * y) - mu2_sq
    sigma12 = _local_mean(x * y) - mu1_mu2

    # Stabilising constants (K1 = 0.01, K2 = 0.03) scaled by the data range.
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    ssim_map = numerator / denominator
    return ssim_map.mean()


def dssim_loss(x, y, data_range=2.0):
    """Return the structural dissimilarity ``1 - SSIM(x, y)``.

    Author's note (translated): preserves the local structure of strokes well.

    Args:
        x: Reconstructed batch, same layout as expected by ``_ssim``.
        y: Reference batch, same layout as ``x``.
        data_range: Dynamic range of the inputs; the default of 2.0
            corresponds to inputs in [-1, 1].

    Returns:
        A scalar tensor equal to ``1 - _ssim(x, y)``.
    """
    return 1.0 - _ssim(x, y, data_range=data_range)
|
| 368 |
+
|
| 369 |
class MedianLossNormalizer:
|
| 370 |
def __init__(self, desired_ratios: dict, window_steps: int):
|
| 371 |
s = sum(desired_ratios.values())
|
|
|
|
| 527 |
latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
|
| 528 |
rec = current_vae.decode(latents).sample
|
| 529 |
|
|
|
|
|
|
|
|
|
|
| 530 |
rec_f32 = rec.to(torch.float32)
|
| 531 |
imgs_f32 = imgs.to(torch.float32)
|
| 532 |
|
| 533 |
+
# --- Расчёт лоссов ---
|
| 534 |
abs_losses = {
|
| 535 |
"mae": F.l1_loss(rec_f32, imgs_f32),
|
| 536 |
+
"mse": torch.tensor(0.0, device=accelerator.device, dtype=torch.float32), # Отключён, экономим вычисления
|
| 537 |
"lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
|
| 538 |
"fdl": fdl_loss(rec_f32, imgs_f32),
|
| 539 |
+
"dssim": dssim_loss(rec_f32, imgs_f32), # <-- DSSIM вместо Edge
|
| 540 |
}
|
| 541 |
|
| 542 |
if full_training and not train_decoder_only:
|