babkasotona commited on
Commit
3d2849c
·
verified ·
1 Parent(s): 58c8c94

Upload folder using huggingface_hub

Browse files
diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ea0a9e3cec9df3c46e9ff4da0cc62f1a6bb546c747e30ed7a8c9b5abb0058b4
3
  size 383499124
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0963803789980cdf3df4d4e87b26c099ccd3ed08b78deb7a12afaba5e193c97
3
  size 383499124
train_sdxs_vae.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import gc
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import torchvision.transforms as transforms
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ # Import standard and asymmetric VAEs only
17
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
18
+
19
+ from accelerate import Accelerator
20
+ from PIL import Image, UnidentifiedImageError
21
+ from tqdm import tqdm
22
+ import bitsandbytes as bnb
23
+ import wandb
24
+ import lpips # pip install lpips
25
+ from FDL_pytorch import FDL_loss # pip install fdl-pytorch
26
+ from collections import deque
27
+
28
+ # --- Configuration ---
29
+ DATASET_PATH = "/workspace/d23/alchemist"
30
+ PROJECT_NAME = "vae7"
31
+ BATCH_SIZE = 1
32
+ BASE_LEARNING_RATE = 4e-6
33
+ MIN_LEARNING_RATE = 4e-7
34
+ NUM_EPOCHS = 8
35
+ SAMPLE_INTERVAL_SHARE = 2
36
+ USE_WANDB = False
37
+ SAVE_MODEL = True
38
+ USE_DECAY = True
39
+ OPTIMIZER_TYPE = "adam8bit"
40
+ DTYPE = torch.float32
41
+
42
+ MODEL_RESOLUTION = 576
43
+ HIGH_RESOLUTION = 1152
44
+ DATA_LIMIT = 0 # Limit dataset size (0 for no limit)
45
+ SAVE_BARRIER = 1.3
46
+ WARMUP_PERCENT = 0.005
47
+ BETA2 = 0.997
48
+ EPSILON = 1e-8
49
+ CLIP_GRAD_NORM = 1.0
50
+ MIXED_PRECISION = "no"
51
+ GRADIENT_ACCUMULATION_STEPS = 1
52
+ GENERATED_FOLDER = "samples"
53
+ SAVE_AS = "vae8"
54
+ NUM_WORKERS = 0
55
+
56
+ # Enable deterministic training and optimizations
57
+ torch.backends.cuda.matmul.allow_tf32 = True
58
+ torch.backends.cudnn.allow_tf32 = True
59
+ torch.backends.cuda.enable_flash_sdp(True)
60
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
61
+ torch.backends.cuda.enable_math_sdp(False)
62
+
63
+ # --- Training Modes ---
64
+ TRAIN_DECODER_ONLY = True
65
+ TRAIN_UP_ONLY = False
66
+ FULL_TRAINING = False
67
+ KL_RATIO = 0.0
68
+
69
+ # --- Loss Ratios ---
70
+ LOSS_RATIOS = {
71
+ "lpips": 0.60,
72
+ "fdl" : 0.10,
73
+ "mse": 0.06,
74
+ "mae": 0.12,
75
+ "dssim": 0.06,
76
+ "kl": 0.00,
77
+ "edge": 0.06,
78
+ }
79
+ MEDIAN_COEFF_STEPS = 250
80
+
81
+ # --- VAE Type ---
82
+ # 'kl' for standard AutoencoderKL, 'asymmetric' for AsymmetricAutoencoderKL
83
+ VAE_TYPE = "asymmetric"
84
+
85
+ Path(GENERATED_FOLDER).mkdir(parents=True, exist_ok=True)
86
+
87
+ # Initialize Accelerator
88
+ accelerator = Accelerator(
89
+ mixed_precision=MIXED_PRECISION,
90
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS
91
+ )
92
+ device = accelerator.device
93
+
94
+ # Set seeds for reproducibility
95
+ seed = int(datetime.now().strftime("%Y%m%d")) + 42
96
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
97
+ torch.backends.cudnn.benchmark = False
98
+
99
+ # --------------------------- WandB Logging ---------------------------
100
+ if USE_WANDB and accelerator.is_main_process:
101
+ wandb.init(project=PROJECT_NAME, config={
102
+ "batch_size": BATCH_SIZE,
103
+ "base_learning_rate": BASE_LEARNING_RATE,
104
+ "num_epochs": NUM_EPOCHS,
105
+ "optimizer_type": OPTIMIZER_TYPE,
106
+ "model_resolution": MODEL_RESOLUTION,
107
+ "high_resolution": HIGH_RESOLUTION,
108
+ "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
109
+ "train_decoder_only": TRAIN_DECODER_ONLY,
110
+ "full_training": FULL_TRAINING,
111
+ "kl_ratio": KL_RATIO,
112
+ "vae_type": VAE_TYPE,
113
+ })
114
+
115
+ # --------------------------- VAE Model Loading ---------------------------
116
+ def get_core_model(model):
117
+ """Unwraps a model potentially wrapped by torch.compile."""
118
+ if hasattr(model, "_orig_mod"):
119
+ model = model._orig_mod
120
+ return model
121
+
122
+ # Load the appropriate VAE model (Video VAEs completely removed)
123
+ if VAE_TYPE == "asymmetric":
124
+ vae = AsymmetricAutoencoderKL.from_pretrained(PROJECT_NAME)
125
+ elif VAE_TYPE == "kl":
126
+ vae = AutoencoderKL.from_pretrained(PROJECT_NAME)
127
+ else:
128
+ raise ValueError(f"Unsupported VAE_TYPE: {VAE_TYPE}")
129
+
130
+ vae = vae.to(DTYPE)
131
+
132
+ # Apply torch.compile
133
+ if hasattr(torch, "compile"):
134
+ try:
135
+ vae = torch.compile(vae)
136
+ print("[INFO] torch.compile applied successfully.")
137
+ except Exception as e:
138
+ print(f"[WARN] torch.compile failed: {e}")
139
+
140
+ # --------------------------- Freeze/Unfreeze Parameters ---------------------------
141
+ core = get_core_model(vae)
142
+
143
+ for p in core.parameters():
144
+ p.requires_grad = False
145
+
146
+ unfrozen_param_names = []
147
+
148
+ if FULL_TRAINING and not TRAIN_DECODER_ONLY:
149
+ for name, p in core.named_parameters():
150
+ p.requires_grad = True
151
+ unfrozen_param_names.append(name)
152
+ LOSS_RATIOS["kl"] = float(KL_RATIO)
153
+ trainable_module = core
154
+ else:
155
+ if hasattr(core, "decoder"):
156
+ if TRAIN_UP_ONLY and hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
157
+ for name, p in core.decoder.up_blocks[0].named_parameters():
158
+ p.requires_grad = True
159
+ unfrozen_param_names.append(f"decoder.up_blocks[0].{name}")
160
+ else:
161
+ print("[INFO] Decoder: Falling back to training the full decoder.")
162
+ for name, p in core.decoder.named_parameters():
163
+ p.requires_grad = True
164
+ unfrozen_param_names.append(f"decoder.{name}")
165
+ if hasattr(core, "post_quant_conv"):
166
+ for name, p in core.post_quant_conv.named_parameters():
167
+ p.requires_grad = True
168
+ unfrozen_param_names.append(f"post_quant_conv.{name}")
169
+ trainable_module = core.decoder if hasattr(core, "decoder") else core
170
+
171
+ print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 10 names:")
172
+ for nm in unfrozen_param_names[:10]:
173
+ print(f" {nm}")
174
+
175
+ # --------------------------- Dataset Preparation ---------------------------
176
+ class PngFolderDataset(Dataset):
177
+ def __init__(self, root_dir, resolution=1024, min_exts=('.png',), limit=0):
178
+ self.resolution = resolution
179
+ self.paths = []
180
+ for root, _, files in os.walk(root_dir):
181
+ for f in files:
182
+ if f.lower().endswith(tuple(ext.lower() for ext in min_exts)):
183
+ self.paths.append(os.path.join(root, f))
184
+ if limit > 0:
185
+ self.paths = self.paths[:limit]
186
+
187
+ valid_paths = []
188
+ for p in self.paths:
189
+ try:
190
+ with Image.open(p) as img:
191
+ img.verify()
192
+ w, h = img.size
193
+ if w < resolution or h < resolution:
194
+ continue
195
+ valid_paths.append(p)
196
+ except (OSError, UnidentifiedImageError) as e:
197
+ print(f"[WARN] Skipping invalid image file {p}: {e}")
198
+ self.paths = valid_paths
199
+ if not self.paths:
200
+ raise RuntimeError(f"No valid images found in {root_dir}")
201
+ random.shuffle(self.paths)
202
+ self.transform = transforms.ToTensor()
203
+
204
+ def __len__(self): return len(self.paths)
205
+
206
+ def __getitem__(self, idx):
207
+ p = self.paths[idx % len(self.paths)]
208
+ try:
209
+ with Image.open(p) as img:
210
+ return img.convert("RGB")
211
+ except Exception as e:
212
+ print(f"[ERROR] Failed to load image {p}: {e}")
213
+ return Image.new("RGB", (self.resolution, self.resolution), 'red')
214
+
215
+ def random_crop(img, sz):
216
+ w, h = img.size
217
+ crop_w = min(sz, w)
218
+ crop_h = min(sz, h)
219
+ x = random.randint(0, max(0, w - crop_w))
220
+ y = random.randint(0, max(0, h - crop_h))
221
+ return img.crop((x, y, x + crop_w, y + crop_h))
222
+
223
+ input_tfm = transforms.Compose([
224
+ transforms.ToTensor(),
225
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
226
+ ])
227
+
228
+ def collate_fn(batch):
229
+ imgs = []
230
+ for img in batch:
231
+ img = random_crop(img, HIGH_RESOLUTION)
232
+ imgs.append(input_tfm(img))
233
+ return torch.stack(imgs)
234
+
235
+ try:
236
+ dataset = PngFolderDataset(DATASET_PATH, min_exts=('.png', '.PNG'), resolution=HIGH_RESOLUTION, limit=DATA_LIMIT)
237
+ print(f"[INFO] Dataset loaded: {len(dataset)} images.")
238
+ if len(dataset) < BATCH_SIZE:
239
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {BATCH_SIZE}")
240
+
241
+ dataloader = DataLoader(
242
+ dataset, batch_size=BATCH_SIZE, shuffle=True,
243
+ collate_fn=collate_fn, num_workers=NUM_WORKERS,
244
+ pin_memory=True, drop_last=True
245
+ )
246
+ except RuntimeError as e:
247
+ print(f"[ERROR] Failed to initialize dataloader: {e}")
248
+ exit()
249
+
250
+ # --------------------------- Optimizer Setup ---------------------------
251
+ def get_param_groups(module, weight_decay=0.001):
252
+ no_decay_tokens = ("bias", "norm", "rms", "layernorm")
253
+ decay_params, no_decay_params = [], []
254
+ for name, param in module.named_parameters():
255
+ if not param.requires_grad: continue
256
+ name_lower = name.lower()
257
+ if any(token in name_lower for token in no_decay_tokens):
258
+ no_decay_params.append(param)
259
+ else:
260
+ decay_params.append(param)
261
+ return [
262
+ {"params": decay_params, "weight_decay": weight_decay},
263
+ {"params": no_decay_params, "weight_decay": 0.0},
264
+ ]
265
+
266
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
267
+ optimizer = bnb.optim.AdamW8bit(param_groups, lr=BASE_LEARNING_RATE, betas=(0.9, BETA2), eps=EPSILON)
268
+
269
+ # --------------------------- Learning Rate Scheduler ---------------------------
270
+ batches_per_epoch = len(dataloader)
271
+ steps_per_epoch = math.ceil(batches_per_epoch / float(GRADIENT_ACCUMULATION_STEPS))
272
+ total_steps = steps_per_epoch * NUM_EPOCHS
273
+
274
+ def lr_lambda(step):
275
+ if not USE_DECAY: return 1.0
276
+ current_step_fraction = float(step) / float(max(1, total_steps))
277
+ warmup_fraction = float(WARMUP_PERCENT)
278
+ min_lr_ratio = float(MIN_LEARNING_RATE) / float(BASE_LEARNING_RATE)
279
+ if current_step_fraction < warmup_fraction:
280
+ return min_lr_ratio + (1.0 - min_lr_ratio) * (current_step_fraction / warmup_fraction)
281
+ else:
282
+ decay_fraction = (current_step_fraction - warmup_fraction) / (1.0 - warmup_fraction)
283
+ return min_lr_ratio + 0.5 * (1.0 - min_lr_ratio) * (1.0 + math.cos(math.pi * decay_fraction))
284
+
285
+ scheduler = LambdaLR(optimizer, lr_lambda)
286
+
287
+ # --------------------------- Prepare for Training ---------------------------
288
+ (dataloader, vae, optimizer, scheduler) = accelerator.prepare(dataloader, vae, optimizer, scheduler)
289
+ trainable_params = [p for p in vae.parameters() if p.requires_grad]
290
+
291
+ fdl_loss_fn = FDL_loss().to(accelerator.device)
292
+
293
+ _lpips_net = None
294
+ def get_lpips_loss():
295
+ global _lpips_net
296
+ if _lpips_net is None:
297
+ _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
298
+ return _lpips_net
299
+
300
+ def _gaussian_kernel(window_size, sigma, device, dtype):
301
+ coords = torch.arange(window_size, dtype=dtype, device=device) - (window_size - 1) / 2
302
+ k = torch.exp(-coords**2 / (2 * sigma**2))
303
+ return k / k.sum()
304
+
305
+ def _ssim(img1, img2, window_size=11, sigma=1.5):
306
+ channels = img1.shape[1]
307
+ kernel = _gaussian_kernel(window_size, sigma, img1.device, img1.dtype)
308
+ win = (kernel.view(1, 1, -1, 1) * kernel.view(1, 1, 1, -1)).expand(channels, 1, window_size, window_size).contiguous()
309
+ mu1 = F.conv2d(img1, win, padding=window_size//2, groups=channels)
310
+ mu2 = F.conv2d(img2, win, padding=window_size//2, groups=channels)
311
+ mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2
312
+ sigma1_sq = F.conv2d(img1*img1, win, padding=window_size//2, groups=channels) - mu1_sq
313
+ sigma2_sq = F.conv2d(img2*img2, win, padding=window_size//2, groups=channels) - mu2_sq
314
+ sigma12 = F.conv2d(img1*img2, win, padding=window_size//2, groups=channels) - mu1_mu2
315
+
316
+ # ИСПРАВЛЕНО: Разделено присваивание, чтобы избежать UnboundLocalError
317
+ L = 2.0
318
+ C1 = (0.01 * L) ** 2
319
+ C2 = (0.03 * L) ** 2
320
+
321
+ num = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
322
+ den = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
323
+ return (num / den).mean()
324
+
325
+ def edge_loss(img1, img2):
326
+ def get_edges(img):
327
+ C = img.shape[1]
328
+ # Sobel x kernel (horizontal edges)
329
+ sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=img.dtype, device=img.device).view(1, 1, 3, 3)
330
+ # Sobel y kernel (vertical edges)
331
+ sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=img.dtype, device=img.device).view(1, 1, 3, 3)
332
+ # Repeat for each channel
333
+ sobel_x_c = sobel_x.repeat(C, 1, 1, 1)
334
+ sobel_y_c = sobel_y.repeat(C, 1, 1, 1)
335
+ # Apply convolution per channel
336
+ grad_x = F.conv2d(img, sobel_x_c, padding=1, groups=C)
337
+ grad_y = F.conv2d(img, sobel_y_c, padding=1, groups=C)
338
+ # Gradient magnitude
339
+ return torch.sqrt(grad_x**2 + grad_y**2 + 1e-12)
340
+ return F.l1_loss(get_edges(img1), get_edges(img2))
341
+
342
+ def dssim_loss(img1, img2):
343
+ return 1.0 - _ssim(img1, img2)
344
+
345
+ class MedianLossNormalizer:
346
+ def __init__(self, desired_ratios: dict, window_steps: int):
347
+ total_ratio = sum(desired_ratios.values())
348
+ self.ratios = {k: (v / total_ratio) if total_ratio > 0 else 0.0 for k, v in desired_ratios.items()}
349
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
350
+
351
+ def update_and_total(self, absolute_losses: dict):
352
+ for k, v in absolute_losses.items():
353
+ if k in self.buffers:
354
+ self.buffers[k].append(float(v.detach().abs().cpu()))
355
+ medians = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
356
+ coefficients = {k: (self.ratios[k] / max(medians[k], 1e-12)) for k in self.ratios}
357
+ total_loss = sum(coefficients[k] * absolute_losses[k] for k in absolute_losses if k in coefficients)
358
+ return total_loss, coefficients, medians
359
+
360
+ loss_normalizer = MedianLossNormalizer(LOSS_RATIOS, MEDIAN_COEFF_STEPS)
361
+
362
+ # --------------------------- Sample Generation ---------------------------
363
+ @torch.no_grad()
364
+ def get_fixed_samples(n=3):
365
+ indices = random.sample(range(len(dataset)), min(n, len(dataset)))
366
+ tensors = [input_tfm(random_crop(dataset[i], HIGH_RESOLUTION)) for i in indices]
367
+ return torch.stack(tensors).to(accelerator.device, DTYPE)
368
+
369
+ fixed_samples = get_fixed_samples()
370
+
371
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
372
+ arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
373
+ return Image.fromarray(arr)
374
+
375
+ @torch.no_grad()
376
+ def generate_and_save_samples(step=None):
377
+ try:
378
+ unwrapped_vae = accelerator.unwrap_model(vae)
379
+ temp_vae = get_core_model(unwrapped_vae).eval()
380
+ lpips_net = get_lpips_loss()
381
+
382
+ original_high_res = fixed_samples
383
+ input_tensor = original_high_res.to(dtype=next(temp_vae.parameters()).dtype) if MODEL_RESOLUTION == HIGH_RESOLUTION else F.interpolate(original_high_res, size=(MODEL_RESOLUTION, MODEL_RESOLUTION), mode="area")
384
+
385
+ encoder_output = temp_vae.encode(input_tensor)
386
+ latents = encoder_output.latent_dist.mean if TRAIN_DECODER_ONLY else encoder_output.latent_dist.sample()
387
+ reconstructed_images = temp_vae.decode(latents).sample
388
+
389
+ if reconstructed_images.shape[-2:] != original_high_res.shape[-2:]:
390
+ reconstructed_images = F.interpolate(reconstructed_images, size=original_high_res.shape[-2:], mode="bilinear", align_corners=False)
391
+
392
+ for i in range(reconstructed_images.shape[0]):
393
+ _to_pil_uint8(original_high_res[i]).save(os.path.join(GENERATED_FOLDER, f"sample_real_{i}.png"))
394
+ _to_pil_uint8(reconstructed_images[i]).save(os.path.join(GENERATED_FOLDER, f"sample_decoded_{i}.png"))
395
+
396
+ if USE_WANDB and accelerator.is_main_process:
397
+ log_data = {"lpips_mean": float(np.mean([lpips_net(original_high_res[i:i+1], reconstructed_images[i:i+1]).item() for i in range(len(original_high_res))]))}
398
+ for i in range(len(original_high_res)):
399
+ log_data[f"sample/real_{i}"] = wandb.Image(os.path.join(GENERATED_FOLDER, f"sample_real_{i}.png"))
400
+ log_data[f"sample/decoded_{i}"] = wandb.Image(os.path.join(GENERATED_FOLDER, f"sample_decoded_{i}.png"))
401
+ wandb.log(log_data, step=step)
402
+ finally:
403
+ gc.collect()
404
+ torch.cuda.empty_cache()
405
+
406
+ if accelerator.is_main_process and SAVE_MODEL:
407
+ print("[INFO] Generating initial samples before training...")
408
+ generate_and_save_samples(step=0)
409
+
410
+ accelerator.wait_for_everyone()
411
+
412
+ # --------------------------- Training Loop ---------------------------
413
+ progress_bar = tqdm(total=total_steps, desc="Training", disable=not accelerator.is_local_main_process)
414
+ global_step = 0
415
+ min_loss = float("inf")
416
+ num_samples_per_epoch = max(1, int(total_steps / max(1, SAMPLE_INTERVAL_SHARE * NUM_EPOCHS)))
417
+ sample_interval = max(1, int(round(num_samples_per_epoch / GRADIENT_ACCUMULATION_STEPS)))
418
+
419
+ for epoch in range(NUM_EPOCHS):
420
+ vae.train()
421
+ batch_losses_history, batch_grads_history = [], []
422
+ tracked_losses = {k: [] for k in LOSS_RATIOS.keys()}
423
+
424
+ for batch_idx, imgs in enumerate(dataloader):
425
+ with accelerator.accumulate(vae):
426
+ imgs = imgs.to(accelerator.device)
427
+ imgs_low = imgs if MODEL_RESOLUTION == HIGH_RESOLUTION else F.interpolate(imgs, size=(MODEL_RESOLUTION, MODEL_RESOLUTION), mode="area")
428
+ model_dtype = next(vae.parameters()).dtype
429
+ input_images = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
430
+
431
+ current_vae_model = get_core_model(accelerator.unwrap_model(vae))
432
+ encoder_output = current_vae_model.encode(input_images)
433
+ latents = encoder_output.latent_dist.mean if TRAIN_DECODER_ONLY else encoder_output.latent_dist.sample()
434
+ rec_f32 = current_vae_model.decode(latents).sample.to(torch.float32)
435
+ imgs_f32 = imgs.to(torch.float32)
436
+
437
+ mae_loss = F.l1_loss(rec_f32, imgs_f32)
438
+ mse_loss = F.mse_loss(rec_f32, imgs_f32)
439
+ lpips_loss_val = get_lpips_loss()(rec_f32, imgs_f32).mean()
440
+ fdl_loss_val = fdl_loss_fn(rec_f32, imgs_f32)
441
+ dssim_loss_val = dssim_loss(rec_f32, imgs_f32)
442
+ edge_loss_val = edge_loss(rec_f32, imgs_f32)
443
+
444
+ kl_loss = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
445
+ if FULL_TRAINING and not TRAIN_DECODER_ONLY:
446
+ mean = encoder_output.latent_dist.mean
447
+ logvar = encoder_output.latent_dist.logvar
448
+ kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
449
+
450
+ absolute_losses = {
451
+ "mae": mae_loss, "mse": mse_loss, "lpips": lpips_loss_val,
452
+ "fdl": fdl_loss_val, "dssim": dssim_loss_val, "kl": kl_loss,
453
+ "edge": edge_loss_val,
454
+ }
455
+
456
+ total_loss, coeffs, medians = loss_normalizer.update_and_total(absolute_losses)
457
+ if torch.isnan(total_loss) or torch.isinf(total_loss):
458
+ raise RuntimeError("NaN/Inf loss encountered during training.")
459
+
460
+ accelerator.backward(total_loss)
461
+
462
+ current_grad_norm = torch.tensor(0.0, device=accelerator.device)
463
+ if accelerator.sync_gradients:
464
+ current_grad_norm = accelerator.clip_grad_norm_(trainable_params, CLIP_GRAD_NORM)
465
+ optimizer.step()
466
+ scheduler.step()
467
+ optimizer.zero_grad(set_to_none=True)
468
+ global_step += 1
469
+ progress_bar.update(1)
470
+
471
+ if accelerator.is_main_process:
472
+ try: current_lr = optimizer.param_groups[0]["lr"]
473
+ except Exception: current_lr = scheduler.get_last_lr()[0]
474
+
475
+ batch_losses_history.append(total_loss.detach().item())
476
+ batch_grads_history.append(float(current_grad_norm.detach().cpu().item()))
477
+ for k, v in absolute_losses.items():
478
+ tracked_losses[k].append(float(v.detach().item()))
479
+
480
+ if USE_WANDB and accelerator.sync_gradients:
481
+ log_dict = {"total_loss": batch_losses_history[-1], "learning_rate": current_lr, "epoch": epoch, "grad_norm": batch_grads_history[-1]}
482
+ for k, v in absolute_losses.items(): log_dict[f"loss_{k}"] = float(v.detach().item())
483
+ for k in coeffs: log_dict[f"coeff_{k}"] = float(coeffs[k])
484
+ wandb.log(log_dict, step=global_step)
485
+
486
+ if global_step > 0 and global_step % sample_interval == 0:
487
+ if accelerator.is_main_process:
488
+ generate_and_save_samples(step=global_step)
489
+ accelerator.wait_for_everyone()
490
+
491
+ n_logs = min(len(batch_losses_history), sample_interval)
492
+ avg_total = float(np.mean(batch_losses_history[-n_logs:]))
493
+ avg_grad = float(np.mean(batch_grads_history[-n_logs:]))
494
+
495
+ # ЯВНОЕ ЛОГИРОВАНИЕ КОМПОНЕНТ ПОТЕРЬ
496
+ loss_avgs = {k: float(np.mean(tracked_losses[k][-n_logs:])) for k in tracked_losses if len(tracked_losses[k]) >= n_logs}
497
+
498
+ print(f"Epoch {epoch} | Step {global_step} | "
499
+ f"Total: {avg_total:.5f} | "
500
+ f"LPIPS: {loss_avgs.get('lpips', 0):.5f} | "
501
+ f"DSSIM: {loss_avgs.get('dssim', 0):.5f} | "
502
+ f"MAE: {loss_avgs.get('mae', 0):.5f} | "
503
+ f"FDL: {loss_avgs.get('fdl', 0):.5f} | "
504
+ f"EDGE: {loss_avgs.get('edge', 0):.5f} | "
505
+ f"MSE: {loss_avgs.get('mse', 0):.5f} | "
506
+ f"Grad: {avg_grad:.5f} | LR: {current_lr:.9f}")
507
+
508
+ if SAVE_MODEL and avg_total < min_loss * SAVE_BARRIER:
509
+ min_loss = avg_total
510
+ print(f"[INFO] Saving model with improved loss: {min_loss:.6f}")
511
+ get_core_model(accelerator.unwrap_model(vae)).save_pretrained(SAVE_AS)
512
+
513
+ if accelerator.is_main_process:
514
+ print(f"Epoch {epoch} completed. Average Loss: {float(np.mean(batch_losses_history)):.6f}")
515
+
516
+ if accelerator.is_main_process:
517
+ print("Training finished – saving final model.")
518
+ if SAVE_MODEL:
519
+ get_core_model(accelerator.unwrap_model(vae)).save_pretrained(SAVE_AS)
520
+
521
+ accelerator.free_memory()
522
+ if torch.distributed.is_initialized():
523
+ torch.distributed.destroy_process_group()
524
+
525
+ print("Training complete. Done!")