import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Sampler
from collections import defaultdict
from diffusers import UNet2DConditionModel, AutoencoderKL
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image, ImageOps
import wandb
import random
import gc
from diffusers.models.attention_processor import AttnProcessor2_0
from datetime import datetime
import bitsandbytes as bnb

from dataclasses import dataclass
from typing import Tuple, Optional, Union

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
|
|
|
| @dataclass |
| class FlowMatchingEulerSchedulerOutput(BaseOutput): |
| """ |
| Output class for the scheduler's `step` function output. |
| |
| Args: |
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `x_{t-1}` of the previous timestep (in flow-matching notation this would be
            denoted `x_{t+h}`). `prev_sample` should be used as the next model input in the denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `x_{0}` (in flow-matching notation this would be denoted
            `x_{1}`), based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
| """ |
|
|
| prev_sample: torch.Tensor |
| pred_original_sample: Optional[torch.Tensor] = None |
|
|
|
|
| def get_time_coefficients(timestep: torch.Tensor, ndim: int) -> torch.Tensor: |
| """ |
| Convert timestep to time coefficients. |
| Args: |
| timestep (`torch.Tensor`): Timestep tensor. |
| ndim (`int`): Number of dimensions. |
| Returns: |
| `torch.Tensor`: Time coefficients. |
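    Example:
        A timestep tensor of shape `(B,)` with `ndim=4` is reshaped to `(B, 1, 1, 1)`,
        so it broadcasts over image latents.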
| """ |
    return timestep.reshape((timestep.shape[0], *([1] * (ndim - 1))))
|
|
|
|
| class FlowMatchingEulerScheduler(SchedulerMixin, ConfigMixin): |
| """ |
| `FlowMatchingEulerScheduler` is a scheduler for training and inferencing Conditional Flow Matching models (CFMs). |
| |
| Flow Matching (FM) is a novel, simulation-free methodology for training Continuous Normalizing Flows (CNFs) by |
| regressing vector fields of predetermined conditional probability paths, facilitating scalable training and |
| efficient sample generation through the utilization of various probability paths, including Gaussian and |
| Optimal Transport (OT) paths, thereby enhancing model performance and generalization capabilities |
| |
| Args: |
| num_inference_steps (`int`, defaults to 100): |
| The number of steps on inference. |
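
    Example (an illustrative sketch, not invoked by the training code below; `velocity_model(x, t)`
    is a hypothetical callable that predicts the flow velocity):

    ```py
    >>> scheduler = FlowMatchingEulerScheduler(num_inference_steps=50)
    >>> x = torch.randn(1, 4, 64, 64)  # sample from the source (noise) distribution at t = 0
    >>> for t in scheduler.timesteps:
    ...     t = t.expand(x.shape[0])
    ...     x = scheduler.step(velocity_model(x, t), t, x).prev_sample  # one Euler step toward t = 1
    ```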
| """ |
|
|
| @register_to_config |
| def __init__(self, num_inference_steps: int = 100): |
| self.timesteps = None |
| self.num_inference_steps = None |
| self.h = None |
|
|
| if num_inference_steps is not None: |
| self.set_timesteps(num_inference_steps) |
|
|
| @staticmethod |
| def add_noise(original_samples: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: |
| """ |
| Add noise to the given sample |
| |
| Args: |
| original_samples (`torch.Tensor`): |
| The original sample that is to be noised |
| noise (`torch.Tensor`): |
| The noise that is used to noise the image |
| timestep (`torch.Tensor`): |
                Timestep used to create the linear interpolation `x_t = t * x_1 + (1 - t) * x_0`,
                where `x_1` is the target distribution, `x_0` is the source distribution, and `t` ∈ [0, 1].
| """ |
|
|
| t = get_time_coefficients(timestep, original_samples.ndim) |
|
|
| noised_sample = t * original_samples + (1 - t) * noise |
|
|
| return noised_sample |
|
|
| def set_timesteps(self, num_inference_steps: int = 100) -> None: |
| """ |
| Set number of inference steps (Euler intagration steps) |
| |
| Args: |
| num_inference_steps (`int`, defaults to 100): |
| The number of steps on inference. |
| """ |
|
|
| self.num_inference_steps = num_inference_steps |
| self.h = 1 / num_inference_steps |
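        # Integration runs from the source (noise) distribution at t = 0 toward the
        # target (data) distribution at t = 1, in steps of h = 1 / num_inference_steps.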
| self.timesteps = torch.arange(0, 1, self.h) |
|
|
| def step(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor, |
| return_dict: bool = True) -> Union[FlowMatchingEulerSchedulerOutput, Tuple]: |
| """ |
| Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion |
| process from the learned model outputs (most often the predicted noise). |
| |
| Args: |
| model_output (`torch.Tensor`): |
| The direct output from learned diffusion model. |
| timestep (`float`): |
| Timestep used to perform Euler Method `x_t = h * f(x_t, t) + x_{t-1}`. |
| Where x_1 is a target distribution, x_0 is a source distribution and t (timestep) ∈ [0, 1] |
| sample (`torch.Tensor`): |
| A current instance of a sample created by the diffusion process. |
| return_dict (`bool`, *optional*, defaults to `True`): |
| Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`. |
| |
| Returns: |
| [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`: |
| If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a |
| tuple is returned where the first element is the sample tensor. |
| """ |
|
|
| step = FlowMatchingEulerSchedulerOutput( |
| prev_sample=sample + self.h * model_output, |
| pred_original_sample=sample + (1 - get_time_coefficients(timestep, model_output.ndim)) * model_output |
| ) |
|
|
| if return_dict: |
| return step |
|
|
        return (step.prev_sample,)
|
|
| @staticmethod |
| def get_velocity(original_samples: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: |
| """ |
| Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion |
| process from the learned model outputs (most often the predicted noise). |
| |
| Args: |
| original_samples (`torch.Tensor`): |
| The original sample that is to be noised |
| noise (`torch.Tensor`): |
| The noise that is used to noise the image |
| |
| Returns: |
| `torch.Tensor` |
| """ |
|
|
| return original_samples - noise |
|
|
| @staticmethod |
| def scale_model_input(sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: |
| """ |
| Ensures interchangeability with schedulers that need to scale the denoising model input depending on the |
| current timestep. |
| |
| Args: |
| sample (`torch.Tensor`): |
| The input sample. |
| timestep (`int`, *optional*): |
| The current timestep in the diffusion chain. |
| |
| Returns: |
| `torch.Tensor`: |
| A scaled input sample. |
| """ |
|
|
        return sample


| save_path = "datasets/768" |
| batch_size = 30 |
| base_learning_rate = 4e-6 |
| min_learning_rate = 2.5e-5 |
| num_epochs = 1 |
| project = "sdxs" |
use_wandb = True
save_model = True
| limit = 0 |
| checkpoints_folder = "" |
|
|
| |
| n_diffusion_steps = 40 |
| samples_to_generate = 12 |
| guidance_scale = 5 |
| sample_interval_share = 25 |
|
|
| |
| generated_folder = "samples" |
| os.makedirs(generated_folder, exist_ok=True) |
|
|
| |
| current_date = datetime.now() |
| seed = int(current_date.strftime("%Y%m%d")) |
| fixed_seed = True |
| if fixed_seed: |
| torch.manual_seed(seed) |
| np.random.seed(seed) |
| random.seed(seed) |
| if torch.cuda.is_available(): |
| torch.cuda.manual_seed_all(seed) |
|
|
| |
| |
| lora_name = "" |
| lora_rank = 32 |
| lora_alpha = 64 |
|
|
| print("init") |
| |
torch.backends.cuda.enable_flash_sdp(True)  # prefer FlashAttention SDPA kernels where available
| |
| dtype = torch.bfloat16 |
| accelerator = Accelerator(mixed_precision="bf16") |
| device = accelerator.device |
| gen = torch.Generator(device=device) |
| gen.manual_seed(seed) |
|
|
| |
| if use_wandb and accelerator.is_main_process: |
| wandb.init(project=project+lora_name, config={ |
| "batch_size": batch_size, |
| "base_learning_rate": base_learning_rate, |
| "num_epochs": num_epochs, |
| "n_diffusion_steps": n_diffusion_steps, |
| "samples_to_generate": samples_to_generate, |
| "dtype": str(dtype) |
| }) |
|
|
| |
| class ResolutionBatchSampler(Sampler): |
| """Сэмплер, который группирует примеры по одинаковым размерам""" |
| def __init__(self, dataset, batch_size, shuffle=True, drop_last=False): |
| self.dataset = dataset |
| self.batch_size = batch_size |
| self.shuffle = shuffle |
| self.drop_last = drop_last |
| |
| |
| self.size_groups = defaultdict(list) |
|
|
| try: |
| widths = dataset["width"] |
| heights = dataset["height"] |
| except KeyError: |
| widths = [0] * len(dataset) |
| heights = [0] * len(dataset) |
| |
| for i, (w, h) in enumerate(zip(widths, heights)): |
| size = (w, h) |
| self.size_groups[size].append(i) |
| |
| |
| print(f"Найдено {len(self.size_groups)} уникальных размеров:") |
| for size, indices in sorted(self.size_groups.items(), key=lambda x: len(x[1]), reverse=True): |
| width, height = size |
| print(f" {width}x{height}: {len(indices)} примеров") |
| |
| |
| self.reset() |
| |
| def reset(self): |
| """Сбрасывает и перемешивает индексы""" |
| self.batches = [] |
| |
| for size, indices in self.size_groups.items(): |
| if self.shuffle: |
| indices_copy = indices.copy() |
| random.shuffle(indices_copy) |
| else: |
| indices_copy = indices |
| |
| |
| for i in range(0, len(indices_copy), self.batch_size): |
| batch_indices = indices_copy[i:i + self.batch_size] |
| |
| |
| if self.drop_last and len(batch_indices) < self.batch_size: |
| continue |
| |
| self.batches.append(batch_indices) |
| |
| |
| if self.shuffle: |
| random.shuffle(self.batches) |
| |
| def __iter__(self): |
| self.reset() |
| return iter(self.batches) |
| |
| def __len__(self): |
| return len(self.batches) |
|
|
| |
| def get_fixed_samples_by_resolution(dataset, samples_per_group=1): |
| """Выбирает фиксированные семплы для каждого уникального разрешения""" |
| |
| size_groups = defaultdict(list) |
| try: |
| widths = dataset["width"] |
| heights = dataset["height"] |
| except KeyError: |
| widths = [0] * len(dataset) |
| heights = [0] * len(dataset) |
| for i, (w, h) in enumerate(zip(widths, heights)): |
| size = (w, h) |
| size_groups[size].append(i) |
| |
| |
| fixed_samples = {} |
| for size, indices in size_groups.items(): |
| |
| n_samples = min(samples_per_group, len(indices)) |
        if len(size_groups) == 1:
| n_samples = samples_to_generate |
| if n_samples == 0: |
| continue |
| |
| |
| sample_indices = random.sample(indices, n_samples) |
| samples_data = [dataset[idx] for idx in sample_indices] |
| |
| |
| latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).to(device) |
| embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).to(device) |
| texts = [item["text"] for item in samples_data] |
| |
| |
| fixed_samples[size] = (latents, embeddings, texts) |
| |
| print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям") |
| return fixed_samples |
|
|
| if limit > 0: |
| dataset = load_from_disk(save_path).select(range(limit)) |
| else: |
| dataset = load_from_disk(save_path) |
|
|
|
|
|
|
| def collate_fn(batch): |
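    # Each dataset row carries a precomputed VAE latent ("vae") and text-encoder
    # embedding ("embeddings"), so neither the VAE encoder nor a text encoder runs here.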
| |
| latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device) |
| embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device) |
| return latents, embeddings |
| |
| |
| batch_sampler = ResolutionBatchSampler(dataset, batch_size=batch_size, shuffle=True) |
| dataloader = DataLoader(dataset, batch_sampler=batch_sampler) |
|
|
| print("Total samples",len(dataloader)) |
| dataloader = accelerator.prepare(dataloader) |
|
|
| |
| |
vae = AutoencoderKL.from_pretrained("AuraDiffusion/16ch-vae").to("cpu", dtype=dtype)


scheduler = FlowMatchingEulerScheduler()
|
|
| |
| start_epoch = 0 |
| global_step = 0 |
|
|
| |
| total_training_steps = (len(dataloader) * num_epochs) |
| |
| world_size = accelerator.state.num_processes |
| print(f"World Size: {world_size}") |
|
|
| |
| latest_checkpoint = os.path.join(checkpoints_folder, project) |
| if os.path.isdir(latest_checkpoint): |
| print("Загружаем UNet из чекпоинта:", latest_checkpoint) |
| unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device, dtype=dtype) |
| unet.enable_gradient_checkpointing() |
| unet.set_use_memory_efficient_attention_xformers(False) |
| try: |
| unet.set_attn_processor(AttnProcessor2_0()) |
| print("SDPA включен через set_attn_processor.") |
| except Exception as e: |
| print(f"Ошибка при включении SDPA: {e}") |
| print("Попытка использовать enable_xformers_memory_efficient_attention.") |
| unet.set_use_memory_efficient_attention_xformers(True) |
|
|
if lora_name:
    print(f"--- Configuring LoRA via PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
    from peft import LoraConfig, get_peft_model

    unet.requires_grad_(False)
    print("Base UNet parameters frozen.")

    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    )

    # get_peft_model injects the LoRA adapters into the frozen UNet in place,
    # so a separate unet.add_adapter(lora_config) call is not needed.
    peft_unet = get_peft_model(unet, lora_config)

    params_to_optimize = [p for p in peft_unet.parameters() if p.requires_grad]
| |
|
|
| |
| if accelerator.is_main_process: |
| lora_params_count = sum(p.numel() for p in params_to_optimize) |
| total_params_count = sum(p.numel() for p in unet.parameters()) |
| print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}") |
| print(f"Общее количество параметров UNet: {total_params_count:,}") |
|
|
| |
| lora_save_path = os.path.join("lora", lora_name) |
| os.makedirs(lora_save_path, exist_ok=True) |
|
|
| |
| def save_lora_checkpoint(model): |
| if accelerator.is_main_process: |
| print(f"Сохраняем LoRA адаптеры в {lora_save_path}") |
| from peft.utils.save_and_load import get_peft_model_state_dict |
| |
| lora_state_dict = get_peft_model_state_dict(model) |
|
|
| |
| torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin")) |
| |
| |
| model.peft_config["default"].save_pretrained(lora_save_path) |
| |
| from diffusers import StableDiffusionXLPipeline |
| StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict) |
|
|
| |
| |
| if lora_name: |
| |
| trainable_params = [p for p in unet.parameters() if p.requires_grad] |
| else: |
| |
| trainable_params = list(unet.parameters()) |
| |
| |
| optimizer_dict = { |
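    # Memory-saving pattern: one 8-bit AdamW instance per parameter. Together with the
    # post-accumulate-grad hooks registered below, each gradient is applied and freed as
    # soon as backward produces it, so full-model gradients never coexist in memory.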
| p: bnb.optim.AdamW8bit( |
| [p], |
| lr=base_learning_rate, |
| betas=(0.9, 0.999), |
| weight_decay=1e-5, |
| eps=1e-8 |
| ) for p in trainable_params |
| } |
|
|
| |
| def optimizer_hook(param): |
| optimizer_dict[param].step() |
| optimizer_dict[param].zero_grad(set_to_none=True) |
|
|
| |
| for param in trainable_params: |
| param.register_post_accumulate_grad_hook(optimizer_hook) |
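
# Because each parameter steps its own optimizer inside the hook, the training loop
# below only calls accelerator.backward(loss); no explicit optimizer.step() or
# zero_grad() is needed.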
|
|
| |
unet, optimizer_dict = accelerator.prepare(unet, optimizer_dict)
|
|
| |
| |
| fixed_samples = get_fixed_samples_by_resolution(dataset) |
|
|
|
|
| @torch.no_grad() |
def generate_and_save_samples(fixed_samples, step):
    """
    Generate preview samples for every resolution group and save them.

    Args:
        fixed_samples: Dict mapping sizes (width, height) to tuples
                       (latents, embeddings, texts).
        step: Current training step (used for logging).
    """
    original_model = None  # defined before try so the finally block can test it safely
    try:
        original_model = accelerator.unwrap_model(unet)
| |
| vae.to(accelerator.device, dtype=dtype) |
| |
| |
| scheduler.set_timesteps(n_diffusion_steps) |
| |
| all_generated_images = [] |
| size_info = [] |
| all_captions = [] |
| |
| |
| for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples.items(): |
| width, height = size |
| size_info.append(f"{width}x{height}") |
| |
| |
| |
| noise = torch.randn( |
| sample_latents.shape, |
| generator=gen, |
| device=sample_latents.device, |
| dtype=sample_latents.dtype |
| ) |
| |
| |
| current_latents = noise.clone() |
| |
| |
| if guidance_scale > 0: |
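                # Classifier-free guidance: pair each prompt embedding with an empty
                # (zero) embedding so conditional and unconditional passes share one batch.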
| empty_embeddings = torch.zeros_like(sample_text_embeddings) |
| text_embeddings = torch.cat([empty_embeddings, sample_text_embeddings], dim=0) |
| else: |
| text_embeddings = sample_text_embeddings |
| |
| |
| for t in scheduler.timesteps: |
| |
| t = t.unsqueeze(dim=0).to(device) |
| if guidance_scale > 0: |
| latent_model_input = torch.cat([current_latents] * 2) |
| latent_model_input = scheduler.scale_model_input(latent_model_input, t) |
| else: |
| latent_model_input = scheduler.scale_model_input(current_latents, t) |
| |
| |
| noise_pred = original_model(latent_model_input, t, text_embeddings).sample |
| |
| |
| if guidance_scale > 0: |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
| noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
| |
| |
| current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample |
| |
| |
| latent = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor |
| latent = latent.to(accelerator.device, dtype=dtype) |
| decoded = vae.decode(latent).sample |
| |
| |
| for img_idx, img_tensor in enumerate(decoded): |
| img = (img_tensor.to(torch.float32) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0) |
| pil_img = Image.fromarray((img * 255).astype("uint8")) |
| |
| max_width = max(size[0] for size in fixed_samples.keys()) |
| max_height = max(size[1] for size in fixed_samples.keys()) |
| max_width = max(255,max_width) |
| max_height = max(255,max_height) |
| |
| |
| padded_img = ImageOps.pad(pil_img, (max_width, max_height), color='white') |
| |
| all_generated_images.append(padded_img) |
|
|
| caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else "" |
| all_captions.append(caption_text) |
| |
| |
| save_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg" |
| pil_img.save(save_path, "JPEG", quality=96) |
| |
| |
| if use_wandb and accelerator.is_main_process: |
| wandb_images = [ |
| wandb.Image(img, caption=f"{all_captions[i]}") |
| for i, img in enumerate(all_generated_images) |
| ] |
| wandb.log({"generated_images": wandb_images, "global_step": step}) |
| |
    finally:
        # Move the VAE back to CPU and release cached GPU memory.
        vae.to("cpu")
        if original_model is not None:
            del original_model
        torch.cuda.empty_cache()
        gc.collect()
| |
| |
| if accelerator.is_main_process: |
| if save_model: |
| print("Генерация сэмплов до старта обучения...") |
| generate_and_save_samples(fixed_samples,0) |
|
|
| |
| def save_checkpoint(unet): |
| if accelerator.is_main_process: |
| if lora_name: |
| |
| save_lora_checkpoint(unet) |
| else: |
| |
| accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}")) |
|
|
| |
| |
| if accelerator.is_main_process: |
| print(f"Total steps per GPU: {total_training_steps}") |
| print(f"[GPU {accelerator.process_index}] Total steps: {total_training_steps}") |
|
|
| epoch_loss_points = [] |
| progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step") |
|
|
| |
| steps_per_epoch = len(dataloader) |
| sample_interval = max(1, steps_per_epoch // sample_interval_share) |
|
|
| |
| for epoch in range(start_epoch, start_epoch + num_epochs): |
| batch_losses = [] |
| unet.train() |
| |
| for step, (latents, embeddings) in enumerate(dataloader): |
| with accelerator.accumulate(unet): |
            if not save_model and step == 3:
                used_gb = torch.cuda.max_memory_allocated() / 1024**3
                print(f"Step {step}: peak GPU memory {used_gb:.2f} GB")
| |
| noise = torch.randn_like(latents) |
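            # Flow matching: sample t ~ U[0, 1) in 1/1000 increments, build the linear
            # interpolation x_t = t * x_1 + (1 - t) * x_0, and regress the constant
            # velocity v = x_1 - x_0 (see scheduler.get_velocity) with an MSE loss.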
| |
| timesteps = torch.randint( |
| 0, |
| 1000, |
| (latents.shape[0],), |
| device=device |
| ) / 1000 |
| |
| |
| noisy_latents = scheduler.add_noise(latents, noise, timesteps) |
|
|
| |
| noise_pred = unet(noisy_latents, timesteps, embeddings).sample |
| |
| |
| target = scheduler.get_velocity(latents, noise) |
|
|
| |
| loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float()) |
|
|
| |
| accelerator.backward(loss) |
| |
| |
| global_step += 1 |
| |
| |
| progress_bar.update(1) |
| |
| |
| if accelerator.is_main_process: |
| current_lr = base_learning_rate |
| batch_losses.append(loss.detach().item()) |
| |
| |
| if use_wandb: |
| wandb.log({ |
| "loss": loss.detach().item(), |
| "learning_rate": current_lr, |
| "epoch": epoch, |
| "global_step": global_step |
| }) |
| |
| |
| if global_step % sample_interval == 0: |
| if save_model: |
| save_checkpoint(unet) |
| |
| generate_and_save_samples(fixed_samples,global_step) |
| |
| |
| avg_loss = np.mean(batch_losses[-sample_interval:]) |
| |
| if use_wandb: |
| wandb.log({"intermediate_loss": avg_loss}) |
| |
| |
| |
| if accelerator.is_main_process: |
| avg_epoch_loss = np.mean(batch_losses) |
| print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}") |
| if use_wandb: |
| wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1}) |
|
|
| |
| if accelerator.is_main_process: |
| print("Обучение завершено! Сохраняем финальную модель...") |
| |
| |
| save_checkpoint(unet) |
| print("Готово!") |