Instructions to use babkasotona/vae8 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use babkasotona/vae8 with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("babkasotona/vae8", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # -*- coding: utf-8 -*- | |
| import os | |
| import math | |
| import re | |
| import torch | |
| import numpy as np | |
| import random | |
| import gc | |
| from datetime import datetime | |
| from pathlib import Path | |
| import torchvision.transforms as transforms | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, Dataset | |
| from torch.optim.lr_scheduler import LambdaLR | |
| # Import standard and asymmetric VAEs only | |
| from diffusers import AutoencoderKL, AsymmetricAutoencoderKL | |
| from accelerate import Accelerator | |
| from PIL import Image, UnidentifiedImageError | |
| from tqdm import tqdm | |
| import bitsandbytes as bnb | |
| import wandb | |
| import lpips # pip install lpips | |
| from FDL_pytorch import FDL_loss # pip install fdl-pytorch | |
| from collections import deque | |
# --- Configuration ---
# Paths and run identity.
DATASET_PATH = "/workspace/d23/alchemist"  # walked recursively for PNG files
PROJECT_NAME = "vae7"  # W&B project name; also the id/path given to from_pretrained
SAVE_AS = "vae8"  # target passed to save_pretrained
GENERATED_FOLDER = "samples"  # preview PNGs are written here

# Data loading.
BATCH_SIZE = 1
NUM_WORKERS = 0
DATA_LIMIT = 0  # Limit dataset size (0 for no limit)
MODEL_RESOLUTION = 576  # side length fed to the encoder
HIGH_RESOLUTION = 1152  # side length of the training crops / loss targets

# Optimisation.
OPTIMIZER_TYPE = "adam8bit"
BASE_LEARNING_RATE = 4e-6
MIN_LEARNING_RATE = 4e-7
USE_DECAY = True  # when False the LR multiplier stays at 1.0
WARMUP_PERCENT = 0.005  # fraction of total steps spent warming up
BETA2 = 0.997
EPSILON = 1e-8
CLIP_GRAD_NORM = 1.0
GRADIENT_ACCUMULATION_STEPS = 1
NUM_EPOCHS = 8

# Numerics.
DTYPE = torch.float32
MIXED_PRECISION = "no"

# Logging, sampling and checkpointing.
USE_WANDB = False
SAVE_MODEL = True
SAMPLE_INTERVAL_SHARE = 2  # divisor used when deriving the sampling interval
SAVE_BARRIER = 1.3  # multiplier on the reference loss in the checkpoint condition
# Backend toggles: allow TF32 kernels for matmul and cuDNN, and select which
# scaled-dot-product-attention implementations may be dispatched
# (flash and memory-efficient enabled, the plain math path disabled).
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
# --- Training Modes ---
# The whole model is trained only when FULL_TRAINING is set and
# TRAIN_DECODER_ONLY is cleared; otherwise only decoder-side weights are
# unfrozen (just decoder.up_blocks[0] when TRAIN_UP_ONLY is set).
FULL_TRAINING = False
TRAIN_DECODER_ONLY = True
TRAIN_UP_ONLY = False
KL_RATIO = 0.0  # copied into LOSS_RATIOS["kl"] in full-training mode

# --- Loss Ratios ---
# Target share of each term in the combined loss; the shares are re-normalised
# to sum to one before use, and insertion order is the summation order.
LOSS_RATIOS = dict(
    lpips=0.60,
    fdl=0.10,
    mse=0.06,
    mae=0.12,
    dssim=0.06,
    kl=0.00,
    edge=0.06,
)
MEDIAN_COEFF_STEPS = 250  # length of the running-median window, in updates

# --- VAE Type ---
# 'kl' for standard AutoencoderKL, 'asymmetric' for AsymmetricAutoencoderKL
VAE_TYPE = "asymmetric"
# Make sure the preview folder exists before anything tries to write to it.
Path(GENERATED_FOLDER).mkdir(parents=True, exist_ok=True)

# Initialize Accelerator
accelerator = Accelerator(
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    mixed_precision=MIXED_PRECISION,
)
device = accelerator.device

# Seed every RNG in use. The seed is derived from today's date (YYYYMMDD + 42),
# so two runs share a seed only when started on the same calendar day.
seed = 42 + int(datetime.now().strftime("%Y%m%d"))
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
# --------------------------- WandB Logging ---------------------------
# Only the main process opens a W&B run, and only when logging is enabled.
if USE_WANDB and accelerator.is_main_process:
    wandb.init(
        project=PROJECT_NAME,
        config=dict(
            batch_size=BATCH_SIZE,
            base_learning_rate=BASE_LEARNING_RATE,
            num_epochs=NUM_EPOCHS,
            optimizer_type=OPTIMIZER_TYPE,
            model_resolution=MODEL_RESOLUTION,
            high_resolution=HIGH_RESOLUTION,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
            train_decoder_only=TRAIN_DECODER_ONLY,
            full_training=FULL_TRAINING,
            kl_ratio=KL_RATIO,
            vae_type=VAE_TYPE,
        ),
    )
# --------------------------- VAE Model Loading ---------------------------
def get_core_model(model):
    """Return the module stored in `model._orig_mod`, or `model` itself.

    `_orig_mod` is the attribute under which torch.compile keeps the wrapped
    module; objects without that attribute are returned unchanged.
    """
    return model._orig_mod if hasattr(model, "_orig_mod") else model
# Load the appropriate VAE model (Video VAEs completely removed)
if VAE_TYPE not in ("asymmetric", "kl"):
    raise ValueError(f"Unsupported VAE_TYPE: {VAE_TYPE}")
vae_cls = AsymmetricAutoencoderKL if VAE_TYPE == "asymmetric" else AutoencoderKL
# PROJECT_NAME doubles as the model id/path to load from.
vae = vae_cls.from_pretrained(PROJECT_NAME).to(DTYPE)

# Apply torch.compile
# Best effort: when compilation is unavailable or fails, training continues
# with the eager module.
# NOTE(review): later code calls encode()/decode() on get_core_model(...),
# i.e. on the unwrapped module - confirm the compiled wrapper is ever used.
if hasattr(torch, "compile"):
    try:
        vae = torch.compile(vae)
    except Exception as e:
        print(f"[WARN] torch.compile failed: {e}")
    else:
        print("[INFO] torch.compile applied successfully.")
# --------------------------- Freeze/Unfreeze Parameters ---------------------------
# Start from a fully frozen model, then re-enable gradients for the subset
# selected by the training-mode flags.
core = get_core_model(vae)
for frozen_param in core.parameters():
    frozen_param.requires_grad = False

unfrozen_param_names = []


def _unfreeze(module, prefix):
    # Enable gradients for every parameter of `module` and record its name.
    for param_name, param in module.named_parameters():
        param.requires_grad = True
        unfrozen_param_names.append(f"{prefix}{param_name}")


if FULL_TRAINING and not TRAIN_DECODER_ONLY:
    # Train everything; the KL term only gets a weight in this mode.
    _unfreeze(core, "")
    LOSS_RATIOS["kl"] = float(KL_RATIO)
    trainable_module = core
else:
    if hasattr(core, "decoder"):
        up_only = (
            TRAIN_UP_ONLY
            and hasattr(core.decoder, "up_blocks")
            and len(core.decoder.up_blocks) > 0
        )
        if up_only:
            _unfreeze(core.decoder.up_blocks[0], "decoder.up_blocks[0].")
        else:
            print("[INFO] Decoder: Falling back to training the full decoder.")
            _unfreeze(core.decoder, "decoder.")
    if hasattr(core, "post_quant_conv"):
        _unfreeze(core.post_quant_conv, "post_quant_conv.")
    trainable_module = core.decoder if hasattr(core, "decoder") else core

print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 10 names:")
for nm in unfrozen_param_names[:10]:
    print(f" {nm}")
# --------------------------- Dataset Preparation ---------------------------
class PngFolderDataset(Dataset):
    """Dataset of image files found recursively under a folder.

    Files are selected by extension (case-insensitive), optionally truncated
    to `limit` entries, then filtered: a file is kept only if it opens and
    verifies as an image and both sides are at least `resolution` pixels.
    The surviving paths are shuffled once. Items are returned as RGB PIL
    images, not tensors.

    Raises:
        RuntimeError: if no file survives the filtering.
    """

    def __init__(self, root_dir, resolution=1024, min_exts=('.png',), limit=0):
        self.resolution = resolution

        suffixes = tuple(ext.lower() for ext in min_exts)
        candidates = []
        for folder, _, filenames in os.walk(root_dir):
            for filename in filenames:
                if filename.lower().endswith(suffixes):
                    candidates.append(os.path.join(folder, filename))
        # The limit is applied before validation, so fewer than `limit`
        # images may remain afterwards.
        if limit > 0:
            candidates = candidates[:limit]

        usable = []
        for path in candidates:
            try:
                with Image.open(path) as img:
                    img.verify()
                    width, height = img.size
            except (OSError, UnidentifiedImageError) as e:
                print(f"[WARN] Skipping invalid image file {path}: {e}")
                continue
            if width >= resolution and height >= resolution:
                usable.append(path)

        self.paths = usable
        if not self.paths:
            raise RuntimeError(f"No valid images found in {root_dir}")
        random.shuffle(self.paths)
        self.transform = transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Indices wrap around, so any integer maps to some stored path.
        path = self.paths[idx % len(self.paths)]
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception as e:
            # Best effort: a solid red placeholder keeps the batch intact.
            print(f"[ERROR] Failed to load image {path}: {e}")
            return Image.new("RGB", (self.resolution, self.resolution), 'red')
def random_crop(img, sz):
    """Crop a randomly placed window of at most `sz` x `sz` pixels from `img`.

    `img` must expose `.size` as (width, height) and `.crop(box)`, as PIL
    images do. A side shorter than `sz` is kept whole. Two values are drawn
    from the `random` module: the left offset first, then the top offset.
    """
    width, height = img.size
    crop_w, crop_h = min(sz, width), min(sz, height)
    left = random.randint(0, max(0, width - crop_w))
    top = random.randint(0, max(0, height - crop_h))
    box = (left, top, left + crop_w, top + crop_h)
    return img.crop(box)
# PIL image -> float tensor, then per-channel (x - 0.5) / 0.5.
_HALF = [0.5, 0.5, 0.5]
input_tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_HALF, _HALF),
])
def collate_fn(batch):
    """Turn a list of PIL images into one stacked tensor batch.

    Each image is randomly cropped to at most HIGH_RESOLUTION per side and
    passed through `input_tfm`. Stacking requires all crops to have the same
    size.
    """
    return torch.stack([input_tfm(random_crop(image, HIGH_RESOLUTION)) for image in batch])
# Build the dataset and dataloader; a setup failure ends the run with a
# non-zero exit status so that wrappers and job schedulers see the error.
try:
    dataset = PngFolderDataset(DATASET_PATH, min_exts=('.png', '.PNG'), resolution=HIGH_RESOLUTION, limit=DATA_LIMIT)
    print(f"[INFO] Dataset loaded: {len(dataset)} images.")
    if len(dataset) < BATCH_SIZE:
        raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {BATCH_SIZE}")
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=True,  # incomplete trailing batches are discarded
    )
except RuntimeError as e:
    print(f"[ERROR] Failed to initialize dataloader: {e}")
    # `exit()` is injected by the `site` module and, called bare, reports
    # success (status 0); raise SystemExit with an explicit failure code.
    raise SystemExit(1)
# --------------------------- Optimizer Setup ---------------------------
def get_param_groups(module, weight_decay=0.001):
    """Split the trainable parameters of `module` into two optimizer groups.

    A parameter is exempt from weight decay when its lower-cased name contains
    any of "bias", "norm", "rms" or "layernorm". Parameters with
    `requires_grad=False` are skipped.

    Returns:
        A two-element list of param-group dicts: the decayed group (using
        `weight_decay`) first, then the exempt group (weight decay 0.0).
        Either group's "params" list may be empty.
    """
    no_decay_tokens = ("bias", "norm", "rms", "layernorm")
    with_decay, without_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        lowered = name.lower()
        is_exempt = any(token in lowered for token in no_decay_tokens)
        (without_decay if is_exempt else with_decay).append(param)
    decayed_group = {"params": with_decay, "weight_decay": weight_decay}
    exempt_group = {"params": without_decay, "weight_decay": 0.0}
    return [decayed_group, exempt_group]
# Trainable parameters, split by weight-decay policy, go to 8-bit AdamW.
# NOTE: OPTIMIZER_TYPE is not consulted here; AdamW8bit is always used.
param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
optimizer = bnb.optim.AdamW8bit(
    param_groups,
    lr=BASE_LEARNING_RATE,
    betas=(0.9, BETA2),
    eps=EPSILON,
)
# --------------------------- Learning Rate Scheduler ---------------------------
# Schedule length in optimizer steps: batches per epoch divided by the
# accumulation factor (rounded up), times the number of epochs.
batches_per_epoch = len(dataloader)
steps_per_epoch = math.ceil(batches_per_epoch / float(GRADIENT_ACCUMULATION_STEPS))
total_steps = NUM_EPOCHS * steps_per_epoch
def lr_lambda(step):
    """Return the LR multiplier (relative to BASE_LEARNING_RATE) for `step`.

    With USE_DECAY disabled the multiplier is always 1.0. Otherwise it rises
    linearly from MIN_LEARNING_RATE / BASE_LEARNING_RATE to 1.0 over the first
    WARMUP_PERCENT of `total_steps`, then follows a half-cosine back down to
    that floor. Steps past `total_steps` stay at the floor.
    """
    if not USE_DECAY:
        return 1.0
    floor = float(MIN_LEARNING_RATE) / float(BASE_LEARNING_RATE)
    warmup = float(WARMUP_PERCENT)
    # Clamp progress: the scheduler may be stepped beyond total_steps, and an
    # unclamped cosine would turn around and raise the LR again.
    progress = min(1.0, float(step) / float(max(1, total_steps)))
    if progress < warmup:
        return floor + (1.0 - floor) * (progress / warmup)
    decay_span = 1.0 - warmup
    if decay_span <= 0.0:
        # Warmup covers the whole schedule, so there is no decay phase.
        return 1.0
    decay_fraction = (progress - warmup) / decay_span
    return floor + 0.5 * (1.0 - floor) * (1.0 + math.cos(math.pi * decay_fraction))
scheduler = LambdaLR(optimizer, lr_lambda)
# --------------------------- Prepare for Training ---------------------------
# Hand the training objects to Accelerate; the names are rebound to whatever
# accelerator.prepare returns.
dataloader, vae, optimizer, scheduler = accelerator.prepare(
    dataloader, vae, optimizer, scheduler
)
# Parameters that gradient clipping operates on.
trainable_params = [param for param in vae.parameters() if param.requires_grad]
fdl_loss_fn = FDL_loss().to(accelerator.device)
# LPIPS network, created on first use by get_lpips_loss().
_lpips_net = None
def get_lpips_loss():
    """Return the shared LPIPS (VGG) network, creating it on first call.

    The instance is cached in the module-level `_lpips_net`, put in eval mode
    and moved to the accelerator device.
    """
    global _lpips_net
    if _lpips_net is not None:
        return _lpips_net
    network = lpips.LPIPS(net='vgg', verbose=False)
    _lpips_net = network.eval().to(accelerator.device)
    return _lpips_net
| def _gaussian_kernel(window_size, sigma, device, dtype): | |
| coords = torch.arange(window_size, dtype=dtype, device=device) - (window_size - 1) / 2 | |
| k = torch.exp(-coords**2 / (2 * sigma**2)) | |
| return k / k.sum() | |
def _ssim(img1, img2, window_size=11, sigma=1.5):
    """Return the mean SSIM between two image batches.

    Local statistics come from a per-channel (grouped) convolution with a 2-D
    Gaussian window, padded by `window_size // 2`. The stabilising constants
    use a dynamic range of L = 2.0.
    NOTE(review): L = 2.0 presumably reflects inputs in [-1, 1] - confirm
    against the callers.
    """
    n_channels = img1.shape[1]
    pad = window_size // 2
    window_1d = _gaussian_kernel(window_size, sigma, img1.device, img1.dtype)
    # Outer product of the 1-D window with itself, one copy per channel.
    window_2d = window_1d.view(1, 1, -1, 1) * window_1d.view(1, 1, 1, -1)
    win = window_2d.expand(n_channels, 1, window_size, window_size).contiguous()

    def local_mean(x):
        return F.conv2d(x, win, padding=pad, groups=n_channels)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    # FIXED: the assignments are kept separate to avoid an UnboundLocalError.
    L = 2.0
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2
    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return (numerator / denominator).mean()
def edge_loss(img1, img2):
    """Return the mean absolute difference of Sobel gradient magnitudes.

    Each input is filtered per channel (grouped 3x3 convolution, padding 1)
    with the two Sobel kernels; the magnitude is sqrt(gx^2 + gy^2 + 1e-12).
    """

    def sobel_magnitude(batch):
        n_channels = batch.shape[1]
        opts = {"dtype": batch.dtype, "device": batch.device}
        # Derivative along the width axis.
        kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], **opts)
        # Derivative along the height axis.
        kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], **opts)
        # One copy of each kernel per channel for the grouped convolution.
        kernel_x = kernel_x.view(1, 1, 3, 3).repeat(n_channels, 1, 1, 1)
        kernel_y = kernel_y.view(1, 1, 3, 3).repeat(n_channels, 1, 1, 1)
        grad_x = F.conv2d(batch, kernel_x, padding=1, groups=n_channels)
        grad_y = F.conv2d(batch, kernel_y, padding=1, groups=n_channels)
        # The epsilon keeps sqrt differentiable where the gradient is zero.
        return torch.sqrt(grad_x**2 + grad_y**2 + 1e-12)

    edges_a = sobel_magnitude(img1)
    edges_b = sobel_magnitude(img2)
    return F.l1_loss(edges_a, edges_b)
def dssim_loss(img1, img2):
    """Return the structural dissimilarity, 1 - SSIM, of two image batches."""
    similarity = _ssim(img1, img2)
    return 1.0 - similarity
class MedianLossNormalizer:
    """Weight loss terms so each one contributes a target share of the total.

    Every named term keeps a sliding window of its recent absolute values.
    Its coefficient is `ratio / median(window)`, with the ratios normalised to
    sum to one, so a term sitting at its median contributes exactly its ratio.
    """

    def __init__(self, desired_ratios: dict, window_steps: int):
        ratio_sum = sum(desired_ratios.values())
        if ratio_sum > 0:
            self.ratios = {name: weight / ratio_sum for name, weight in desired_ratios.items()}
        else:
            # Nothing to distribute: every term gets a zero share.
            self.ratios = {name: 0.0 for name in desired_ratios}
        self.buffers = {name: deque(maxlen=window_steps) for name in self.ratios}

    def update_and_total(self, absolute_losses: dict):
        """Record the new loss values and return the weighted total.

        Args:
            absolute_losses: mapping of term name to scalar loss tensor.
                Names without a configured ratio are ignored.

        Returns:
            (total_loss, coefficients, medians). The medians include the
            values just recorded; a term with an empty window uses 1.0.
        """
        for name, value in absolute_losses.items():
            window = self.buffers.get(name)
            if window is not None:
                window.append(float(value.detach().abs().cpu()))

        medians = {}
        for name, window in self.buffers.items():
            medians[name] = np.median(window) if len(window) > 0 else 1.0

        coefficients = {}
        for name, ratio in self.ratios.items():
            # Floor the median so an all-zero window cannot divide by zero.
            coefficients[name] = ratio / max(medians[name], 1e-12)

        total_loss = sum(
            coefficients[name] * absolute_losses[name]
            for name in absolute_losses
            if name in coefficients
        )
        return total_loss, coefficients, medians
# Shared normaliser that balances the loss terms during training.
loss_normalizer = MedianLossNormalizer(
    desired_ratios=LOSS_RATIOS,
    window_steps=MEDIAN_COEFF_STEPS,
)
# --------------------------- Sample Generation ---------------------------
def get_fixed_samples(n=3):
    """Return a batch of up to `n` randomly chosen, cropped dataset images.

    Each image is cropped to at most HIGH_RESOLUTION per side and passed
    through `input_tfm`; the stacked batch is moved to the accelerator device
    and cast to DTYPE.
    """
    count = min(n, len(dataset))
    chosen = random.sample(range(len(dataset)), count)
    crops = []
    for index in chosen:
        crops.append(input_tfm(random_crop(dataset[index], HIGH_RESOLUTION)))
    batch = torch.stack(crops)
    return batch.to(accelerator.device, DTYPE)


# Drawn once so every preview shows the same images.
fixed_samples = get_fixed_samples()
def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
    """Convert one image tensor to an 8-bit PIL image.

    Values are clamped to [-1, 1] and mapped linearly onto [0, 255]. The
    tensor's axes are permuted with (1, 2, 0) before the PIL conversion.
    NOTE(review): this presumably expects a single (C, H, W) image - confirm
    against the callers.
    """
    scaled = (img_tensor.float().clamp(-1, 1) + 1.0) * 127.5
    # Round to nearest before the integer cast: a bare .byte() truncates,
    # which biases the saved previews slightly darker.
    pixels = scaled.round().clamp(0, 255).byte()
    return Image.fromarray(pixels.cpu().numpy().transpose(1, 2, 0))
def generate_and_save_samples(step=None):
    """Reconstruct the fixed sample batch and save real/decoded PNG pairs.

    The samples are resized to MODEL_RESOLUTION when it differs from
    HIGH_RESOLUTION, encoded and decoded; the output is resized back to the
    sample size if needed. Files go to GENERATED_FOLDER as
    `sample_real_{i}.png` and `sample_decoded_{i}.png`. With W&B enabled on
    the main process, the images and their mean LPIPS are logged at `step`.

    The pass runs in eval mode without gradient tracking. The model's previous
    train/eval mode is restored and cached memory is released afterwards, even
    if an error occurs.
    """
    model = get_core_model(accelerator.unwrap_model(vae))
    was_training = model.training
    try:
        model.eval()
        # No autograd graph is needed here, and the decoder is trainable.
        with torch.no_grad():
            reference = fixed_samples
            if MODEL_RESOLUTION == HIGH_RESOLUTION:
                model_input = reference
            else:
                model_input = F.interpolate(
                    reference, size=(MODEL_RESOLUTION, MODEL_RESOLUTION), mode="area"
                )
            # Match the model dtype on both paths.
            model_input = model_input.to(dtype=next(model.parameters()).dtype)

            encoded = model.encode(model_input)
            if TRAIN_DECODER_ONLY:
                # Use the distribution mean instead of a random sample.
                latents = encoded.latent_dist.mean
            else:
                latents = encoded.latent_dist.sample()
            decoded = model.decode(latents).sample
            if decoded.shape[-2:] != reference.shape[-2:]:
                decoded = F.interpolate(
                    decoded, size=reference.shape[-2:], mode="bilinear", align_corners=False
                )

            for i in range(decoded.shape[0]):
                _to_pil_uint8(reference[i]).save(os.path.join(GENERATED_FOLDER, f"sample_real_{i}.png"))
                _to_pil_uint8(decoded[i]).save(os.path.join(GENERATED_FOLDER, f"sample_decoded_{i}.png"))

            if USE_WANDB and accelerator.is_main_process:
                lpips_net = get_lpips_loss()
                scores = [
                    lpips_net(reference[i:i + 1].float(), decoded[i:i + 1].float()).item()
                    for i in range(len(reference))
                ]
                log_data = {"lpips_mean": float(np.mean(scores))}
                for i in range(len(reference)):
                    log_data[f"sample/real_{i}"] = wandb.Image(os.path.join(GENERATED_FOLDER, f"sample_real_{i}.png"))
                    log_data[f"sample/decoded_{i}"] = wandb.Image(os.path.join(GENERATED_FOLDER, f"sample_decoded_{i}.png"))
                wandb.log(log_data, step=step)
    finally:
        # Do not leave a training run stuck in eval mode.
        model.train(was_training)
        gc.collect()
        torch.cuda.empty_cache()
# Write a baseline set of previews before any weight is updated. Only the main
# process produces them; all processes then synchronise.
if SAVE_MODEL and accelerator.is_main_process:
    print("[INFO] Generating initial samples before training...")
    generate_and_save_samples(step=0)
accelerator.wait_for_everyone()
# --------------------------- Training Loop ---------------------------
progress_bar = tqdm(total=total_steps, desc="Training", disable=not accelerator.is_local_main_process)
global_step = 0  # counts optimizer steps, the unit total_steps is expressed in
min_loss = float("inf")  # lowest windowed average loss seen so far
num_samples_per_epoch = max(1, int(total_steps / max(1, SAMPLE_INTERVAL_SHARE * NUM_EPOCHS)))
# Number of optimizer steps between preview/checkpoint events.
sample_interval = max(1, int(round(num_samples_per_epoch / GRADIENT_ACCUMULATION_STEPS)))

for epoch in range(NUM_EPOCHS):
    vae.train()
    # Per-epoch histories; they are only filled on the main process.
    batch_losses_history, batch_grads_history = [], []
    tracked_losses = {k: [] for k in LOSS_RATIOS}

    for batch_idx, imgs in enumerate(dataloader):
        with accelerator.accumulate(vae):
            imgs = imgs.to(accelerator.device)
            # The model sees a downscaled copy when the two resolutions
            # differ; the losses always compare against the full-size batch.
            if MODEL_RESOLUTION == HIGH_RESOLUTION:
                imgs_low = imgs
            else:
                imgs_low = F.interpolate(imgs, size=(MODEL_RESOLUTION, MODEL_RESOLUTION), mode="area")
            model_dtype = next(vae.parameters()).dtype
            input_images = imgs_low.to(dtype=model_dtype)

            # NOTE(review): encode()/decode() are called on the unwrapped
            # module, bypassing any wrapper added by accelerator.prepare -
            # confirm this is intended for multi-process runs.
            current_vae_model = get_core_model(accelerator.unwrap_model(vae))
            encoder_output = current_vae_model.encode(input_images)
            if TRAIN_DECODER_ONLY:
                latents = encoder_output.latent_dist.mean
            else:
                latents = encoder_output.latent_dist.sample()

            # All losses are evaluated in float32.
            rec_f32 = current_vae_model.decode(latents).sample.to(torch.float32)
            imgs_f32 = imgs.to(torch.float32)

            mae_loss = F.l1_loss(rec_f32, imgs_f32)
            mse_loss = F.mse_loss(rec_f32, imgs_f32)
            lpips_loss_val = get_lpips_loss()(rec_f32, imgs_f32).mean()
            fdl_loss_val = fdl_loss_fn(rec_f32, imgs_f32)
            dssim_loss_val = dssim_loss(rec_f32, imgs_f32)
            edge_loss_val = edge_loss(rec_f32, imgs_f32)

            # The KL term is only computed when the encoder is trained too.
            kl_loss = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
            if FULL_TRAINING and not TRAIN_DECODER_ONLY:
                mean = encoder_output.latent_dist.mean
                logvar = encoder_output.latent_dist.logvar
                kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

            absolute_losses = {
                "mae": mae_loss, "mse": mse_loss, "lpips": lpips_loss_val,
                "fdl": fdl_loss_val, "dssim": dssim_loss_val, "kl": kl_loss,
                "edge": edge_loss_val,
            }
            total_loss, coeffs, medians = loss_normalizer.update_and_total(absolute_losses)
            if torch.isnan(total_loss) or torch.isinf(total_loss):
                raise RuntimeError("NaN/Inf loss encountered during training.")

            accelerator.backward(total_loss)

            current_grad_norm = torch.tensor(0.0, device=accelerator.device)
            if accelerator.sync_gradients:
                # Reached once per optimizer step, i.e. after the last
                # micro-batch of an accumulation window.
                current_grad_norm = accelerator.clip_grad_norm_(trainable_params, CLIP_GRAD_NORM)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1
                progress_bar.update(1)

        # --- Per-batch bookkeeping (main process only) ---
        if accelerator.is_main_process:
            try:
                current_lr = optimizer.param_groups[0]["lr"]
            except (AttributeError, IndexError, KeyError):
                current_lr = scheduler.get_last_lr()[0]
            batch_losses_history.append(total_loss.detach().item())
            batch_grads_history.append(float(current_grad_norm.detach().cpu().item()))
            for k, v in absolute_losses.items():
                tracked_losses[k].append(float(v.detach().item()))
            if USE_WANDB and accelerator.sync_gradients:
                log_dict = {
                    "total_loss": batch_losses_history[-1],
                    "learning_rate": current_lr,
                    "epoch": epoch,
                    "grad_norm": batch_grads_history[-1],
                }
                for k, v in absolute_losses.items():
                    log_dict[f"loss_{k}"] = float(v.detach().item())
                for k in coeffs:
                    log_dict[f"coeff_{k}"] = float(coeffs[k])
                wandb.log(log_dict, step=global_step)

        # --- Periodic previews, console report and checkpoint ---
        # Gated on sync_gradients so the event fires once per optimizer step
        # rather than once per accumulated micro-batch.
        if accelerator.sync_gradients and global_step > 0 and global_step % sample_interval == 0:
            if accelerator.is_main_process:
                generate_and_save_samples(step=global_step)
                # Sampling switches the model to eval mode; switch back.
                vae.train()
            accelerator.wait_for_everyone()

            # The histories and current_lr exist only on the main process.
            if accelerator.is_main_process:
                n_logs = min(len(batch_losses_history), sample_interval)
                avg_total = float(np.mean(batch_losses_history[-n_logs:]))
                avg_grad = float(np.mean(batch_grads_history[-n_logs:]))
                # Explicit logging of the individual loss components.
                loss_avgs = {k: float(np.mean(tracked_losses[k][-n_logs:])) for k in tracked_losses if len(tracked_losses[k]) >= n_logs}
                print(f"Epoch {epoch} | Step {global_step} | "
                      f"Total: {avg_total:.5f} | "
                      f"LPIPS: {loss_avgs.get('lpips', 0):.5f} | "
                      f"DSSIM: {loss_avgs.get('dssim', 0):.5f} | "
                      f"MAE: {loss_avgs.get('mae', 0):.5f} | "
                      f"FDL: {loss_avgs.get('fdl', 0):.5f} | "
                      f"EDGE: {loss_avgs.get('edge', 0):.5f} | "
                      f"MSE: {loss_avgs.get('mse', 0):.5f} | "
                      f"Grad: {avg_grad:.5f} | LR: {current_lr:.9f}")
                if SAVE_MODEL and avg_total < min_loss * SAVE_BARRIER:
                    # Keep the true minimum: assigning avg_total directly
                    # would let the reference drift upward within the
                    # SAVE_BARRIER tolerance.
                    min_loss = min(min_loss, avg_total)
                    print(f"[INFO] Saving model with improved loss: {avg_total:.6f}")
                    get_core_model(accelerator.unwrap_model(vae)).save_pretrained(SAVE_AS)

    if accelerator.is_main_process:
        print(f"Epoch {epoch} completed. Average Loss: {float(np.mean(batch_losses_history)):.6f}")

progress_bar.close()

if accelerator.is_main_process:
    print("Training finished – saving final model.")
    if SAVE_MODEL:
        get_core_model(accelerator.unwrap_model(vae)).save_pretrained(SAVE_AS)

accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
print("Training complete. Done!")