#from comet_ml import Experiment import os import math import torch import numpy as np import matplotlib.pyplot as plt from torch.utils.data import DataLoader, Sampler from torch.utils.data.distributed import DistributedSampler from torch.optim.lr_scheduler import LambdaLR from collections import defaultdict from diffusers import UNet2DConditionModel, AutoencoderKL from accelerate import Accelerator from datasets import load_from_disk from tqdm import tqdm from PIL import Image, ImageOps import wandb import random import gc from accelerate.state import DistributedType from torch.distributed import broadcast_object_list from torch.utils.checkpoint import checkpoint from diffusers.models.attention_processor import AttnProcessor2_0 from datetime import datetime import bitsandbytes as bnb import torch.nn.functional as F from collections import deque from transformers import AutoTokenizer, AutoModel # --------------------------- Параметры --------------------------- ds_path = "/workspace/sdxs/datasets/768" project = "unet" batch_size = 36 base_learning_rate = 2.7e-5 #4e-5 min_learning_rate = 1e-5 #2.7e-5 num_epochs = 80 sample_interval_share = 5 max_length = 192 use_wandb = True use_comet_ml = False save_model = True use_decay = True fbp = False optimizer_type = "adam8bit" torch_compile = False unet_gradient = True fixed_seed = False shuffle = True comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" comet_ml_workspace = "recoilme" torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True #torch.backends.cuda.enable_mem_efficient_sdp(False) dtype = torch.float32 save_barrier = 1.01 warmup_percent = 0.01 percentile_clipping = 96 #97 betta2 = 0.999 eps = 1e-7 clip_grad_norm = 1.0 limit = 0 checkpoints_folder = "" mixed_precision = "no" gradient_accumulation_steps = 1 accelerator = Accelerator( mixed_precision=mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps ) device = accelerator.device # Параметры для диффузии n_diffusion_steps = 40 samples_to_generate = 12 guidance_scale = 4 # Папки для сохранения результатов generated_folder = "samples" os.makedirs(generated_folder, exist_ok=True) # Настройка seed current_date = datetime.now() seed = int(current_date.strftime("%Y%m%d")) if fixed_seed: torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) # --------------------------- Параметры LoRA --------------------------- lora_name = "" lora_rank = 32 lora_alpha = 64 print("init") loss_ratios = { "mse": 1., } median_coeff_steps = 256 # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ class MedianLossNormalizer: def __init__(self, desired_ratios: dict, window_steps: int): # нормируем доли на случай, если сумма != 1 s = sum(desired_ratios.values()) self.ratios = {k: (v / s) for k, v in desired_ratios.items()} self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()} self.window = window_steps def update_and_total(self, losses: dict): """ losses: dict ключ->тензор (значения лоссов) Поведение: - буферим ABS(l) только для активных (ratio>0) лоссов - coeff = ratio / median(abs(loss)) - total = sum(coeff * loss) по активным лоссам CHANGED: буферим abs() — чтобы медиана была положительной и не ломала деление. """ # буферим только активные лоссы for k, v in losses.items(): if k in self.buffers and self.ratios.get(k, 0) > 0: self.buffers[k].append(float(v.detach().abs().cpu())) meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers} coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios} # суммируем только по активным (ratio>0) total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0) return total, coeffs, meds # создаём normalizer после определения loss_ratios normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps) # --------------------------- Инициализация WandB --------------------------- if accelerator.is_main_process: if use_wandb: wandb.init(project=project+lora_name, config={ "batch_size": batch_size, "base_learning_rate": base_learning_rate, "num_epochs": num_epochs, "optimizer_type": optimizer_type, }) if use_comet_ml: from comet_ml import Experiment comet_experiment = Experiment( api_key=comet_ml_api_key, project_name=project, workspace=comet_ml_workspace ) hyper_params = { "batch_size": batch_size, "base_learning_rate": base_learning_rate, "num_epochs": num_epochs, } comet_experiment.log_parameters(hyper_params) # Включение Flash Attention 2/SDPA torch.backends.cuda.enable_flash_sdp(True) # --------------------------- Загрузка моделей --------------------------- vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval() tokenizer = AutoTokenizer.from_pretrained("tokenizer") text_model = AutoModel.from_pretrained("text_encoder").to(device).eval() # --- [UPDATED] Функция кодирования текста (с маской и пулингом) --- def encode_texts(texts, max_length=max_length): # Если тексты пустые (для unconditional), создаем заглушки if texts is None: # В случае None возвращаем нули (логика для get_negative_embedding) # Но здесь мы обычно ожидаем список строк. pass with torch.no_grad(): if isinstance(texts, str): texts = [texts] for i, prompt_item in enumerate(texts): messages = [ {"role": "user", "content": prompt_item}, ] prompt_item = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, #enable_thinking=True, ) #print(prompt_item+"\n") texts[i] = prompt_item toks = tokenizer( texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length ).to(device) outs = text_model(**toks, output_hidden_states=True, return_dict=True) # Используем last_hidden_state или hidden_states[-1] (если Qwen, лучше last_hidden_state - прим человека: ХУЙ) hidden = outs.hidden_states[-2] # 2. Маска внимания attention_mask = toks["attention_mask"] # 3. Пулинг-эмбеддинг (Последний токен) sequence_lengths = attention_mask.sum(dim=1) - 1 batch_size = hidden.shape[0] pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths] #return hidden, attention_mask # --- НОВАЯ ЛОГИКА: ОБЪЕДИНЕНИЕ ДЛЯ КРОСС-ВНИМАНИЯ --- # 1. Расширяем пулинг-вектор до последовательности [B, 1, emb] pooled_expanded = pooled.unsqueeze(1) # 2. Объединяем последовательность токенов и пулинг-вектор # !!! ИЗМЕНЕНИЕ ЗДЕСЬ !!!: Пулинг идет ПЕРВЫМ # Теперь: [B, 1 + L, emb]. Пулинг стал токеном в НАЧАЛЕ. new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1) # 3. Обновляем маску внимания для нового токена # Маска внимания: [B, 1 + L]. Добавляем 1 в НАЧАЛО. # torch.ones((batch_size, 1), device=device) создает маску [B, 1] со значениями 1. new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1) return new_encoder_hidden_states, new_attention_mask shift_factor = getattr(vae.config, "shift_factor", 0.0) if shift_factor is None: shift_factor = 0.0 scaling_factor = getattr(vae.config, "scaling_factor", 1.0) if scaling_factor is None: scaling_factor = 1.0 from diffusers import FlowMatchEulerDiscreteScheduler num_train_timesteps = 1000 scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps) class DistributedResolutionBatchSampler(Sampler): def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True): self.dataset = dataset self.batch_size = max(1, batch_size // num_replicas) self.num_replicas = num_replicas self.rank = rank self.shuffle = shuffle self.drop_last = drop_last self.epoch = 0 try: widths = np.array(dataset["width"]) heights = np.array(dataset["height"]) except KeyError: widths = np.zeros(len(dataset)) heights = np.zeros(len(dataset)) self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0) self.size_groups = {} for w, h in self.size_keys: mask = (widths == w) & (heights == h) self.size_groups[(w, h)] = np.where(mask)[0] self.group_num_batches = {} total_batches = 0 for size, indices in self.size_groups.items(): num_full_batches = len(indices) // (self.batch_size * self.num_replicas) self.group_num_batches[size] = num_full_batches total_batches += num_full_batches self.num_batches = (total_batches // self.num_replicas) * self.num_replicas def __iter__(self): if torch.cuda.is_available(): torch.cuda.empty_cache() all_batches = [] rng = np.random.RandomState(self.epoch) for size, indices in self.size_groups.items(): indices = indices.copy() if self.shuffle: rng.shuffle(indices) num_full_batches = self.group_num_batches[size] if num_full_batches == 0: continue valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas] batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas) start_idx = self.rank * self.batch_size end_idx = start_idx + self.batch_size gpu_batches = batches[:, start_idx:end_idx] all_batches.extend(gpu_batches) if self.shuffle: rng.shuffle(all_batches) accelerator.wait_for_everyone() return iter(all_batches) def __len__(self): return self.num_batches def set_epoch(self, epoch): self.epoch = epoch # --- [UPDATED] Функция для фиксированных семплов --- def get_fixed_samples_by_resolution(dataset, samples_per_group=1): size_groups = defaultdict(list) try: widths = dataset["width"] heights = dataset["height"] except KeyError: widths = [0] * len(dataset) heights = [0] * len(dataset) for i, (w, h) in enumerate(zip(widths, heights)): size = (w, h) size_groups[size].append(i) fixed_samples = {} for size, indices in size_groups.items(): n_samples = min(samples_per_group, len(indices)) if len(size_groups)==1: n_samples = samples_to_generate if n_samples == 0: continue sample_indices = random.sample(indices, n_samples) samples_data = [dataset[idx] for idx in sample_indices] latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype) texts = [item["text"] for item in samples_data] # Кодируем тексты на лету, чтобы получить маски и пулинг embeddings, masks = encode_texts(texts) fixed_samples[size] = (latents, embeddings, masks, texts) print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям") return fixed_samples if limit > 0: dataset = load_from_disk(ds_path).select(range(limit)) else: dataset = load_from_disk(ds_path) # --- [UPDATED] Collate Function --- def collate_fn_simple(batch): # 1. Латенты (VAE) latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype) # 2. Текст берем сырой из датасета raw_texts = [item["text"] for item in batch] texts = [ "" if t.lower().startswith("zero") else "" if random.random() < 0.05 else t[1:].lstrip() if t.startswith(".") else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip() for t in raw_texts ] # 3. Кодируем на лету # Возвращает: hidden (B, L, D), mask (B, L) embeddings, attention_mask = encode_texts(texts) # attention_mask от токенизатора уже имеет нужный формат, но на всякий случай приведем к long attention_mask = attention_mask.to(dtype=torch.int64) return latents, embeddings, attention_mask batch_sampler = DistributedResolutionBatchSampler( dataset=dataset, batch_size=batch_size, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=shuffle ) dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple) print("Total samples", len(dataloader)) dataloader = accelerator.prepare(dataloader) start_epoch = 0 global_step = 0 total_training_steps = (len(dataloader) * num_epochs) world_size = accelerator.state.num_processes # Загрузка UNet latest_checkpoint = os.path.join(checkpoints_folder, project) if os.path.isdir(latest_checkpoint): print("Загружаем UNet из чекпоинта:", latest_checkpoint) unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype) if unet_gradient: unet.enable_gradient_checkpointing() unet.set_use_memory_efficient_attention_xformers(False) try: unet.set_attn_processor(AttnProcessor2_0()) except Exception as e: print(f"Ошибка при включении SDPA: {e}") unet.set_use_memory_efficient_attention_xformers(True) else: raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}") if lora_name: # ... (Код LoRA без изменений, опущен для краткости, если не используется, иначе раскомментируйте оригинальный блок) ... pass # Оптимизатор if lora_name: trainable_params = [p for p in unet.parameters() if p.requires_grad] else: if fbp: trainable_params = list(unet.parameters()) def create_optimizer(name, params): if name == "adam8bit": return bnb.optim.AdamW8bit( params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01, percentile_clipping=percentile_clipping ) elif name == "adam": return torch.optim.AdamW( params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01 ) elif name == "muon": from muon import MuonWithAuxAdam trainable_params = [p for p in params if p.requires_grad] hidden_weights = [p for p in trainable_params if p.ndim >= 2] hidden_gains_biases = [p for p in trainable_params if p.ndim < 2] param_groups = [ dict(params=hidden_weights, use_muon=True, lr=1e-3, weight_decay=1e-4), dict(params=hidden_gains_biases, use_muon=False, lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4), ] optimizer = MuonWithAuxAdam(param_groups) from snooc import SnooC return SnooC(optimizer) else: raise ValueError(f"Unknown optimizer: {name}") if fbp: optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params} def optimizer_hook(param): optimizer_dict[param].step() optimizer_dict[param].zero_grad(set_to_none=True) for param in trainable_params: param.register_post_accumulate_grad_hook(optimizer_hook) unet, optimizer = accelerator.prepare(unet, optimizer_dict) else: optimizer = create_optimizer(optimizer_type, unet.parameters()) def lr_schedule(step): x = step / (total_training_steps * world_size) warmup = warmup_percent if not use_decay: return base_learning_rate if x < warmup: return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup) decay_ratio = (x - warmup) / (1 - warmup) return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \ (1 + math.cos(math.pi * decay_ratio)) lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate) unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler) if torch_compile: print("compiling") unet = torch.compile(unet) print("compiling - ok") # Фиксированные семплы fixed_samples = get_fixed_samples_by_resolution(dataset) # --- [UPDATED] Функция для негативного эмбеддинга (возвращает 3 элемента) --- def get_negative_embedding(neg_prompt="", batch_size=1): if not neg_prompt: hidden_dim = 2048 seq_len = max_length empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device) empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device) return empty_emb, empty_mask uncond_emb, uncond_mask = encode_texts([neg_prompt]) uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1) uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1) return uncond_emb, uncond_mask # Получаем негативные (пустые) условия для валидации uncond_emb, uncond_mask = get_negative_embedding("low quality") # --- Функция генерации семплов --- @torch.compiler.disable() @torch.no_grad() def generate_and_save_samples(fixed_samples_cpu, uncond_data, step): uncond_emb, uncond_mask = uncond_data original_model = None try: if not torch_compile: original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval() else: original_model = unet.eval() vae.to(device=device).eval() all_generated_images = [] all_captions = [] # Распаковываем 5 элементов (добавились mask) for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items(): width, height = size sample_latents = sample_latents.to(dtype=dtype, device=device) sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device) sample_mask = sample_mask.to(device=device) latents = torch.randn( sample_latents.shape, device=device, dtype=sample_latents.dtype, generator=torch.Generator(device=device).manual_seed(seed) ) scheduler.set_timesteps(n_diffusion_steps, device=device) for t in scheduler.timesteps: if guidance_scale != 1: latent_model_input = torch.cat([latents, latents], dim=0) # Подготовка батчей для CFG (Negative + Positive) # 1. Embeddings curr_batch_size = sample_text_embeddings.shape[0] seq_len = sample_text_embeddings.shape[1] hidden_dim = sample_text_embeddings.shape[2] neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1) text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0) # 2. Masks neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1) attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0) else: latent_model_input = latents text_embeddings_batch = sample_text_embeddings attention_mask_batch = sample_mask # Предсказание с передачей всех условий model_out = original_model( latent_model_input, t, encoder_hidden_states=text_embeddings_batch, encoder_attention_mask=attention_mask_batch, ) flow = getattr(model_out, "sample", model_out) if guidance_scale != 1: flow_uncond, flow_cond = flow.chunk(2) flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond) latents = scheduler.step(flow, t, latents).prev_sample current_latents = latents latent_for_vae = current_latents.detach() / scaling_factor + shift_factor decoded = vae.decode(latent_for_vae.to(torch.float32)).sample decoded_fp32 = decoded.to(torch.float32) for img_idx, img_tensor in enumerate(decoded_fp32): img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy() img = img.transpose(1, 2, 0) if np.isnan(img).any(): print("NaNs found, saving stopped! Step:", step) pil_img = Image.fromarray((img * 255).astype("uint8")) max_w_overall = max(s[0] for s in fixed_samples_cpu.keys()) max_h_overall = max(s[1] for s in fixed_samples_cpu.keys()) max_w_overall = max(255, max_w_overall) max_h_overall = max(255, max_h_overall) padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white') all_generated_images.append(padded_img) caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else "" all_captions.append(caption_text) sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg" pil_img.save(sample_path, "JPEG", quality=96) if use_wandb and accelerator.is_main_process: wandb_images = [ wandb.Image(img, caption=f"{all_captions[i]}") for i, img in enumerate(all_generated_images) ] wandb.log({"generated_images": wandb_images}) if use_comet_ml and accelerator.is_main_process: for i, img in enumerate(all_generated_images): comet_experiment.log_image( image_data=img, name=f"step_{step}_img_{i}", step=step, metadata={"caption": all_captions[i]} ) finally: vae.to("cpu") torch.cuda.empty_cache() gc.collect() # --------------------------- Генерация сэмплов перед обучением --------------------------- if accelerator.is_main_process: if save_model: print("Генерация сэмплов до старта обучения...") generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0) accelerator.wait_for_everyone() def save_checkpoint(unet, variant=""): if accelerator.is_main_process: if lora_name: save_lora_checkpoint(unet) else: model_to_save = None if not torch_compile: model_to_save = accelerator.unwrap_model(unet) else: model_to_save = unet if variant != "": model_to_save.to(dtype=torch.float16).save_pretrained( os.path.join(checkpoints_folder, f"{project}"), variant=variant ) else: model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}")) unet = unet.to(dtype=dtype) # --------------------------- Тренировочный цикл --------------------------- if accelerator.is_main_process: print(f"Total steps per GPU: {total_training_steps}") epoch_loss_points = [] progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step") steps_per_epoch = len(dataloader) sample_interval = max(1, steps_per_epoch // sample_interval_share) min_loss = 2. for epoch in range(start_epoch, start_epoch + num_epochs): batch_losses = [] batch_grads = [] batch_sampler.set_epoch(epoch) accelerator.wait_for_everyone() unet.train() for step, (latents, embeddings, attention_mask) in enumerate(dataloader): with accelerator.accumulate(unet): if save_model == False and step == 5 : used_gb = torch.cuda.max_memory_allocated() / 1024**3 print(f"Шаг {step}: {used_gb:.2f} GB") # шум noise = torch.randn_like(latents, dtype=latents.dtype) # берём t из [0, 1] t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype) # интерполяция между x0 и шумом noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise # делаем integer timesteps для UNet timesteps = (t * scheduler.config.num_train_timesteps).long() # --- Вызов UNet с маской --- model_pred = unet( noisy_latents, timesteps, encoder_hidden_states=embeddings, encoder_attention_mask=attention_mask ).sample target = noise - latents mse_loss = F.mse_loss(model_pred.float(), target.float()) batch_losses.append(mse_loss.detach().item()) if (global_step % 100 == 0) or (global_step % sample_interval == 0): accelerator.wait_for_everyone() losses_dict = {} losses_dict["mse"] = mse_loss # === Нормализация всех лоссов === abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()} total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm) if (global_step % 100 == 0) or (global_step % sample_interval == 0): accelerator.wait_for_everyone() accelerator.backward(total_loss) if (global_step % 100 == 0) or (global_step % sample_interval == 0): accelerator.wait_for_everyone() grad = 0.0 if not fbp: if accelerator.sync_gradients: #with torch.amp.autocast('cuda', enabled=False): grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm) grad = float(grad_val) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) if accelerator.sync_gradients: global_step += 1 progress_bar.update(1) if accelerator.is_main_process: if fbp: current_lr = base_learning_rate else: current_lr = lr_scheduler.get_last_lr()[0] batch_grads.append(grad) log_data = {} log_data["loss"] = mse_loss.detach().item() log_data["lr"] = current_lr log_data["grad"] = grad log_data["loss_total"] = float(total_loss.item()) for k, c in coeffs.items(): log_data[f"coeff_{k}"] = float(c) if accelerator.sync_gradients: if use_wandb: wandb.log(log_data, step=global_step) if use_comet_ml: comet_experiment.log_metrics(log_data, step=global_step) if global_step % sample_interval == 0: # Передаем tuple (emb, mask) для негатива generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step) last_n = sample_interval if save_model: has_losses = len(batch_losses) > 0 avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0 last_loss = batch_losses[-1] if has_losses else 0.0 max_loss = max(avg_sample_loss, last_loss) should_save = max_loss < min_loss * save_barrier print( f"Saving: {should_save} | Max: {max_loss:.4f} | " f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}" ) # 6. Сохранение и обновление if should_save: min_loss = max_loss save_checkpoint(unet) if accelerator.is_main_process: avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0 avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0 print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}") log_data_ep = { "epoch_loss": avg_epoch_loss, "epoch_grad": avg_epoch_grad, "epoch": epoch + 1, } if use_wandb: wandb.log(log_data_ep) if use_comet_ml: comet_experiment.log_metrics(log_data_ep) if accelerator.is_main_process: print("Обучение завершено! Сохраняем финальную модель...") if save_model: save_checkpoint(unet,"fp16") if use_comet_ml: comet_experiment.end() accelerator.free_memory() if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() print("Готово!")