recoilme committed
Commit b21699b · 1 Parent(s): 772daf3
samples/unet_192x384_0.jpg → result_grid.jpg RENAMED
File without changes
samples/unet_256x384_0.jpg DELETED

Git LFS Details

  • SHA256: 71c029b2d342b339035eb1c00d1179520b3df6c1887c9f6ffb6dd69ec794628d
  • Pointer size: 130 Bytes
  • Size of remote file: 28.3 kB
samples/unet_320x384_0.jpg DELETED

Git LFS Details

  • SHA256: ff649a103e26e2d7509eac66cedc82b56d850f4792332404086b874dda69b644
  • Pointer size: 130 Bytes
  • Size of remote file: 75.7 kB
samples/unet_384x192_0.jpg DELETED

Git LFS Details

  • SHA256: b6da1cf2877c0e41e237c59360cd30afad1ac968557d4f31e1850cbf1d31855f
  • Pointer size: 130 Bytes
  • Size of remote file: 29.8 kB
samples/unet_384x256_0.jpg DELETED

Git LFS Details

  • SHA256: 5d487645de57527c39ba7e161c2516458ec71d1455c87de17052fa575a631aff
  • Pointer size: 130 Bytes
  • Size of remote file: 33.4 kB
samples/unet_384x320_0.jpg DELETED

Git LFS Details

  • SHA256: ff097668f8dda8afae87d66b1a16f648278691add6927104930831f1afc66963
  • Pointer size: 130 Bytes
  • Size of remote file: 52 kB
samples/unet_384x384_0.jpg DELETED

Git LFS Details

  • SHA256: 38df045e8f0233c43f14738f6ec8270c14eafe6b95ea8fed83650f182bf6c01c
  • Pointer size: 130 Bytes
  • Size of remote file: 86.3 kB
src/model_create.ipynb CHANGED
@@ -569,7 +569,7 @@
   "\n",
   "\n",
   "if 1:\n",
-  " checkpoint_path = \"/workspace/sdxs3d/unet\"#\"sdxs\"\n",
+  " checkpoint_path = \"/workspace/sdxs3d/butterfly\"#\"sdxs\"\n",
   " import torch\n",
   " from diffusers import UNet2DConditionModel\n",
   " print(\"test unet\")\n",
src/sample.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
train.py CHANGED
@@ -31,7 +31,7 @@ project = "unet"
 batch_size = 128
 base_learning_rate = 8e-5
 min_learning_rate = 2e-5
-num_epochs = 100
+num_epochs = 50
 # samples/save per epoch
 sample_interval_share = 5
 use_wandb = True
@@ -71,7 +71,7 @@ device = accelerator.device
 # Diffusion parameters
 n_diffusion_steps = 50
 samples_to_generate = 12
-guidance_scale = 1
+guidance_scale = 4
 
 # Folders for saving results
 generated_folder = "samples"
@@ -184,7 +184,16 @@ if scaling_factor is None:
 latents_mean = getattr(vae.config, "latents_mean", None)
 latents_std = getattr(vae.config, "latents_std", None)
 
+from diffusers import FlowMatchEulerDiscreteScheduler
 
+# Adjust to your own setup
+num_train_timesteps = 1000
+
+scheduler = FlowMatchEulerDiscreteScheduler(
+    num_train_timesteps=num_train_timesteps,
+    #shift=3.0,  # example; tune if needed
+    #use_dynamic_shifting=True
+)
 
 
 class DistributedResolutionBatchSampler(Sampler):
@@ -420,9 +429,73 @@ else:
 # --------------------------- Fixed samples for generation ---------------------------
 fixed_samples = get_fixed_samples_by_resolution(dataset)
 
+def get_negative_embedding(neg_prompt="", batch_size=1):
+    """
+    Returns the negative-prompt embedding with a batch dimension.
+    Loads the models, computes the embedding, then offloads the models to CPU.
+    """
+    import torch
+    from transformers import AutoTokenizer, AutoModel
+
+    # Settings
+    dtype = torch.float16
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Load the models (if not loaded yet)
+    if not hasattr(get_negative_embedding, "tokenizer"):
+        get_negative_embedding.tokenizer = AutoTokenizer.from_pretrained(
+            "Qwen/Qwen3-Embedding-0.6B", padding_side="left"
+        )
+        get_negative_embedding.text_model = AutoModel.from_pretrained(
+            "Qwen/Qwen3-Embedding-0.6B"
+        ).to(device).eval()
+
+    # Pooling helper
+    def last_token_pool(last_hidden_states, attention_mask):
+        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+        if left_padding:
+            return last_hidden_states[:, -1]
+        else:
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            batch_size = last_hidden_states.shape[0]
+            return last_hidden_states[
+                torch.arange(batch_size, device=last_hidden_states.device),
+                sequence_lengths
+            ]
+
+    # Compute the embedding
+    def encode_texts(texts, max_length=512):
+        with torch.inference_mode():
+            toks = get_negative_embedding.tokenizer(
+                texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length
+            ).to(device)
+            outs = get_negative_embedding.text_model(**toks)
+            emb = last_token_pool(outs.last_hidden_state, toks["attention_mask"])
+            emb = emb.unsqueeze(1)  # add a sequence_length dimension
+            return emb
+
+    # Return the embedding
+    if not neg_prompt:
+        hidden_dim = 1024  # embedding size of Qwen3-Embedding-0.6B
+        return torch.zeros((batch_size, 1, hidden_dim), dtype=dtype, device=device)
+
+    uncond_emb = encode_texts([neg_prompt]).to(dtype=dtype, device=device)
+    uncond_emb = uncond_emb.repeat(batch_size, 1, 1)  # add the batch dimension
+
+    # Offload the models
+    if hasattr(get_negative_embedding, "text_model"):
+        get_negative_embedding.text_model = get_negative_embedding.text_model.to("cpu")
+    if hasattr(get_negative_embedding, "tokenizer"):
+        del get_negative_embedding.tokenizer  # free memory
+    torch.cuda.empty_cache()
+
+    return uncond_emb
+
+uncond_emb = get_negative_embedding("low quality")
+
 @torch.compiler.disable()
 @torch.no_grad()
-def generate_and_save_samples(fixed_samples_cpu, step):
+def generate_and_save_samples(fixed_samples_cpu, empty_embeddings, step):
     original_model = None
     try:
         original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
@@ -436,39 +509,41 @@ def generate_and_save_samples(fixed_samples_cpu, step):
         width, height = size
         sample_latents = sample_latents.to(dtype=dtype, device=device)
         sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
-
-        noise = torch.randn(
+
+        # initial noise
+        latents = torch.randn(
            sample_latents.shape,
-           generator=gen,
            device=device,
-           dtype=sample_latents.dtype
+           dtype=sample_latents.dtype,
+           generator=torch.Generator(device=device).manual_seed(seed)
        )
-        current_latents = noise.clone()
-
-        if guidance_scale != 1:
-            empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
-            text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
-        else:
-            text_embeddings_batch = sample_text_embeddings
-
-        timesteps = torch.linspace(0, 1, n_diffusion_steps+1, device=device, dtype=sample_latents.dtype)
-        for i in range(0, n_diffusion_steps):
-            t_cur = timesteps[i].unsqueeze(0)
-            t_next = timesteps[i+1]
-            dt = t_next - t_cur
+
+        # prepare timesteps via the scheduler
+        scheduler.set_timesteps(n_diffusion_steps, device=device)
+
+        for t in scheduler.timesteps:
+            # guidance: double the batch
             if guidance_scale != 1:
-                latent_model_input = torch.cat((current_latents, current_latents))
+                latent_model_input = torch.cat([latents, latents], dim=0)
+                text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
             else:
-                latent_model_input = current_latents
-            t_batch = t_cur.repeat(latent_model_input.shape[0]).to(device)
-            t_batch = (t_batch * 1000).long().view(-1)
-            flow = original_model(latent_model_input, t_batch, text_embeddings_batch).sample
-
+                latent_model_input = latents
+                text_embeddings_batch = sample_text_embeddings
+
+            # flow (velocity) prediction
+            model_out = original_model(latent_model_input, t, encoder_hidden_states=text_embeddings_batch)
+            flow = getattr(model_out, "sample", model_out)
+
+            # combine guidance
             if guidance_scale != 1:
                 flow_uncond, flow_cond = flow.chunk(2)
                 flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
+
+            # step via the scheduler
+            latents = scheduler.step(flow, t, latents).prev_sample
+
+        current_latents = latents
 
-            current_latents = current_latents + flow * dt.to(device)
 
         # Normalization parameters
         latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
@@ -535,7 +610,7 @@ def generate_and_save_samples(fixed_samples_cpu, step):
 if accelerator.is_main_process:
     if save_model:
         print("Generating samples before training starts...")
-        generate_and_save_samples(fixed_samples, 0)
+        generate_and_save_samples(fixed_samples, uncond_emb, 0)
 accelerator.wait_for_everyone()
 
 # Modify the model-saving function to support LoRA
@@ -574,15 +649,26 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
                 used_gb = torch.cuda.max_memory_allocated() / 1024**3
                 print(f"Step {step}: {used_gb:.2f} GB")
 
+            # noise
             noise = torch.randn_like(latents, dtype=latents.dtype)
-            t = logit_normal_samples((latents.shape[0], 1, 1, 1), mu=0.0, sigma=1.0, device=latents.device, dtype=latents.dtype)
-            noisy_latents = (1 - t) * noise + t * latents
-
-            t_for_unet = (t * 1000).long().view(-1)
-            model_pred = unet(noisy_latents, t_for_unet, embeddings).sample
-            target_pred = latents - noise
-
-            mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
+
+            # draw t from [0, 1]
+            t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
+
+            # interpolate between x0 and noise
+            noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
+
+            # integer timesteps for the UNet
+            timesteps = (t * scheduler.config.num_train_timesteps).long()
+
+            # flow prediction
+            model_pred = unet(noisy_latents, timesteps, embeddings).sample
+
+            # the target is the vector field (= difference between the endpoints)
+            target = noise - latents  # or latents - noise?
+
+            # MSE loss
+            mse_loss = F.mse_loss(model_pred.float(), target.float())
 
             # Keep for the logs (we log MSE separately, as an indicator)
             batch_losses.append(mse_loss.detach().item())
@@ -629,7 +715,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
 
                 # Generate samples at the configured interval
                 if global_step % sample_interval == 0:
-                    generate_and_save_samples(fixed_samples, global_step)
+                    generate_and_save_samples(fixed_samples, uncond_emb, global_step)
                     last_n = sample_interval
 
                 if save_model:
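Note on the open question left in the new training loop ("target = noise - latents  # or latents - noise?"): with the interpolation x_t = (1 - t) * latents + t * noise used in this commit, the exact velocity is d x_t / d t = noise - latents, so the committed target matches its own parameterization; the old loop used x_t = (1 - t) * noise + t * latents and therefore latents - noise. The sketch below is not part of the commit, only a minimal numerical check of that statement (tensor shapes are illustrative).

import torch

# Stand-ins for VAE latents and Gaussian noise (shapes are illustrative only).
latents = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(latents)

t = torch.rand(latents.shape[0]).view(-1, 1, 1, 1)
dt = 1e-3

# Interpolation used by the new loop: t = 0 -> data, t = 1 -> noise.
x_t = (1.0 - t) * latents + t * noise
x_t_dt = (1.0 - (t + dt)) * latents + (t + dt) * noise

finite_diff = (x_t_dt - x_t) / dt   # numerical d x_t / d t
target = noise - latents            # target used in the commit

assert torch.allclose(finite_diff, target, atol=1e-4)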
train_chatgpt.py DELETED
@@ -1,678 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from collections import defaultdict
10
- from diffusers import UNet2DConditionModel, AutoencoderKL
11
- from accelerate import Accelerator
12
- from datasets import load_from_disk
13
- from tqdm import tqdm
14
- from PIL import Image, ImageOps
15
- import wandb
16
- import random
17
- import gc
18
- from accelerate.state import DistributedType
19
- from torch.distributed import broadcast_object_list
20
- from torch.utils.checkpoint import checkpoint
21
- from diffusers.models.attention_processor import AttnProcessor2_0
22
- from datetime import datetime
23
- import bitsandbytes as bnb
24
- import torch.nn.functional as F
25
- from collections import deque
26
-
27
- # --------------------------- Параметры ---------------------------
28
- ds_path = "/workspace/sdxs3d/datasets/butterfly"
29
- project = "unet"
30
- batch_size = 128
31
- base_learning_rate = 9e-5
32
- min_learning_rate = 1e-5
33
- num_epochs = 100
34
- # samples/save per epoch
35
- sample_interval_share = 1
36
- use_wandb = True
37
- use_comet_ml = False
38
- save_model = False
39
- use_decay = True
40
- fbp = False # fused backward pass
41
- optimizer_type = "adam8bit"
42
- torch_compile = False
43
- unet_gradient = True
44
- clip_sample = False #Scheduler
45
- fixed_seed = True
46
- shuffle = True
47
- comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
48
- comet_ml_workspace = "recoilme"
49
- torch.backends.cuda.matmul.allow_tf32 = True
50
- torch.backends.cudnn.allow_tf32 = True
51
- torch.backends.cuda.enable_mem_efficient_sdp(False)
52
- dtype = torch.float32
53
- save_barrier = 1.01
54
- warmup_percent = 0.01
55
- percentile_clipping = 99 # 8bit optim
56
- betta2 = 0.995
57
- eps = 1e-8
58
- clip_grad_norm = 1.0
59
- steps_offset = 0 # Scheduler
60
- limit = 0
61
- checkpoints_folder = ""
62
- mixed_precision = "no" # "fp16"
63
- gradient_accumulation_steps = 1
64
- accelerator = Accelerator(
65
- mixed_precision=mixed_precision,
66
- gradient_accumulation_steps=gradient_accumulation_steps
67
- )
68
- device = accelerator.device
69
-
70
- # Параметры для диффузии
71
- n_diffusion_steps = 50
72
- samples_to_generate = 12
73
- guidance_scale = 1
74
-
75
- # Папки для сохранения результатов
76
- generated_folder = "samples"
77
- os.makedirs(generated_folder, exist_ok=True)
78
-
79
- # Настройка seed для воспроизводимости
80
- current_date = datetime.now()
81
- seed = int(current_date.strftime("%Y%m%d"))
82
- if fixed_seed:
83
- torch.manual_seed(seed)
84
- np.random.seed(seed)
85
- random.seed(seed)
86
- if torch.cuda.is_available():
87
- torch.cuda.manual_seed_all(seed)
88
-
89
- # --------------------------- Параметры LoRA ---------------------------
90
- lora_name = ""
91
- lora_rank = 32
92
- lora_alpha = 64
93
-
94
- print("init")
95
-
96
- # --------------------------- вспомогательные функции ---------------------------
97
- def sample_timesteps_bias(
98
- batch_size: int,
99
- progress: float, # [0..1]
100
- num_train_timesteps: int, # обычно 1000
101
- steps_offset: int = 0,
102
- device=None,
103
- mode: str = "beta", # "beta", "uniform"
104
- ) -> torch.Tensor:
105
- max_idx = num_train_timesteps - 1 - steps_offset
106
-
107
- if mode == "beta":
108
- alpha = 1.0 + .5 * (1.0 - progress)
109
- beta = 1.0 + .5 * progress
110
- samples = torch.distributions.Beta(alpha, beta).sample((batch_size,))
111
-
112
- elif mode == "uniform":
113
- samples = torch.rand(batch_size)
114
-
115
- else:
116
- raise ValueError(f"Unknown mode: {mode}")
117
-
118
- timesteps = steps_offset + (samples * max_idx).long().to(device)
119
- return timesteps
120
-
121
-
122
- def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
123
- normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
124
- logit_normal_samples = torch.sigmoid(normal_samples)
125
- return logit_normal_samples
126
-
127
- # --------------------------- Инициализация WandB ---------------------------
128
- if accelerator.is_main_process:
129
- if use_wandb:
130
- wandb.init(project=project + lora_name, config={
131
- "batch_size": batch_size,
132
- "base_learning_rate": base_learning_rate,
133
- "num_epochs": num_epochs,
134
- "fbp": fbp,
135
- "optimizer_type": optimizer_type,
136
- })
137
- if use_comet_ml:
138
- from comet_ml import Experiment
139
- comet_experiment = Experiment(
140
- api_key=comet_ml_api_key,
141
- project_name=project,
142
- workspace=comet_ml_workspace
143
- )
144
- hyper_params = {
145
- "batch_size": batch_size,
146
- "base_learning_rate": base_learning_rate,
147
- "min_learning_rate": min_learning_rate,
148
- "num_epochs": num_epochs,
149
- "n_diffusion_steps": n_diffusion_steps,
150
- "guidance_scale": guidance_scale,
151
- "optimizer_type": optimizer_type,
152
- "mixed_precision": mixed_precision,
153
- }
154
- comet_experiment.log_parameters(hyper_params)
155
-
156
- # Включение Flash Attention 2/SDPA
157
- torch.backends.cuda.enable_flash_sdp(True)
158
- # --------------------------- Инициализация Accelerator --------------------
159
- gen = torch.Generator(device=device)
160
- gen.manual_seed(seed)
161
-
162
- # --------------------------- Загрузка моделей ---------------------------
163
- vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
164
-
165
- shift_factor = getattr(vae.config, "shift_factor", 0.0)
166
- if shift_factor is None:
167
- shift_factor = 0.0
168
-
169
- scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
170
- if scaling_factor is None:
171
- scaling_factor = 1.0
172
-
173
- latents_mean = getattr(vae.config, "latents_mean", None)
174
- latents_std = getattr(vae.config, "latents_std", None)
175
-
176
-
177
- class DistributedResolutionBatchSampler(Sampler):
178
- def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
179
- self.dataset = dataset
180
- self.batch_size = max(1, batch_size // num_replicas)
181
- self.num_replicas = num_replicas
182
- self.rank = rank
183
- self.shuffle = shuffle
184
- self.drop_last = drop_last
185
- self.epoch = 0
186
-
187
- try:
188
- widths = np.array(dataset["width"])
189
- heights = np.array(dataset["height"])
190
- except KeyError:
191
- widths = np.zeros(len(dataset))
192
- heights = np.zeros(len(dataset))
193
-
194
- self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
195
- self.size_groups = {}
196
- for w, h in self.size_keys:
197
- mask = (widths == w) & (heights == h)
198
- self.size_groups[(w, h)] = np.where(mask)[0]
199
-
200
- self.group_num_batches = {}
201
- total_batches = 0
202
- for size, indices in self.size_groups.items():
203
- num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
204
- self.group_num_batches[size] = num_full_batches
205
- total_batches += num_full_batches
206
-
207
- self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
208
-
209
- def __iter__(self):
210
- if torch.cuda.is_available():
211
- torch.cuda.empty_cache()
212
- all_batches = []
213
- rng = np.random.RandomState(self.epoch)
214
-
215
- for size, indices in self.size_groups.items():
216
- indices = indices.copy()
217
- if self.shuffle:
218
- rng.shuffle(indices)
219
- num_full_batches = self.group_num_batches[size]
220
- if num_full_batches == 0:
221
- continue
222
- valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
223
- batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
224
- start_idx = self.rank * self.batch_size
225
- end_idx = start_idx + self.batch_size
226
- gpu_batches = batches[:, start_idx:end_idx]
227
- all_batches.extend(gpu_batches)
228
-
229
- if self.shuffle:
230
- rng.shuffle(all_batches)
231
- accelerator.wait_for_everyone()
232
- return iter(all_batches)
233
-
234
- def __len__(self):
235
- return self.num_batches
236
-
237
- def set_epoch(self, epoch):
238
- self.epoch = epoch
239
-
240
- # Функция для выборки фиксированных семплов по размерам
241
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
242
- size_groups = defaultdict(list)
243
- try:
244
- widths = dataset["width"]
245
- heights = dataset["height"]
246
- except KeyError:
247
- widths = [0] * len(dataset)
248
- heights = [0] * len(dataset)
249
- for i, (w, h) in enumerate(zip(widths, heights)):
250
- size = (w, h)
251
- size_groups[size].append(i)
252
-
253
- fixed_samples = {}
254
- for size, indices in size_groups.items():
255
- n_samples = min(samples_per_group, len(indices))
256
- if len(size_groups) == 1:
257
- n_samples = samples_to_generate
258
- if n_samples == 0:
259
- continue
260
- sample_indices = random.sample(indices, n_samples)
261
- samples_data = [dataset[idx] for idx in sample_indices]
262
-
263
- # FIXED: keep fixed samples on CPU to avoid device/device-transfer issues when creating dataset
264
- latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).cpu()
265
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).cpu()
266
- texts = [item["text"] for item in samples_data]
267
- fixed_samples[size] = (latents, embeddings, texts)
268
-
269
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
270
- return fixed_samples
271
-
272
- if limit > 0:
273
- dataset = load_from_disk(ds_path).select(range(limit))
274
- else:
275
- dataset = load_from_disk(ds_path)
276
-
277
-
278
- def collate_fn_simple(batch):
279
- latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
280
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
281
- return latents, embeddings
282
-
283
- batch_sampler = DistributedResolutionBatchSampler(
284
- dataset=dataset,
285
- batch_size=batch_size,
286
- num_replicas=accelerator.num_processes,
287
- rank=accelerator.process_index,
288
- shuffle=shuffle
289
- )
290
-
291
- # NOTE: we create dataloader first, then prepare it with accelerator. We'll create optimizer/lr_scheduler after
292
- # we know len(dataloader) (which is per-process after prepare) so that scheduling is consistent.
293
-
294
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
295
- print("Total batches (pre-prepare):", len(dataloader))
296
- dataloader = accelerator.prepare(dataloader)
297
-
298
- # --------------------------- Теперь безопасно --- вычисляем шаги и создаём оптимизатор/шедулер ---------------------------
299
- steps_per_epoch = len(dataloader) # this is per-process (after prepare)
300
- total_training_steps = steps_per_epoch * num_epochs
301
- print(f"[rank {accelerator.process_index}] steps_per_epoch={steps_per_epoch}, total_training_steps={total_training_steps}")
302
-
303
- # --------------------------- Создание/загрузка UNet ---------------------------
304
- latest_checkpoint = os.path.join(checkpoints_folder, project)
305
- if os.path.isdir(latest_checkpoint):
306
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
307
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
308
- if torch_compile:
309
- print("compiling")
310
- torch.set_float32_matmul_precision('high')
311
- unet = torch.compile(unet)
312
- print("compiling - ok")
313
- if unet_gradient:
314
- unet.enable_gradient_checkpointing()
315
- unet.set_use_memory_efficient_attention_xformers(False)
316
- try:
317
- unet.set_attn_processor(AttnProcessor2_0())
318
- except Exception as e:
319
- print(f"Ошибка при включении SDPA: {e}")
320
- unet.set_use_memory_efficient_attention_xformers(True)
321
- else:
322
- raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
323
-
324
- # --------------------------- LoRA (если нужно) ---------------------------
325
- if lora_name:
326
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
327
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
328
- from peft.tuners.lora import LoraModel
329
- unet.requires_grad_(False)
330
- print("Параметры базового UNet заморожены.")
331
-
332
- lora_config = LoraConfig(
333
- r=lora_rank,
334
- lora_alpha=lora_alpha,
335
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
336
- )
337
- unet.add_adapter(lora_config)
338
-
339
- from peft import get_peft_model
340
- peft_unet = get_peft_model(unet, lora_config)
341
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
342
-
343
- if accelerator.is_main_process:
344
- lora_params_count = sum(p.numel() for p in params_to_optimize)
345
- total_params_count = sum(p.numel() for p in unet.parameters())
346
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
347
- print(f"Общее количество параметров UNet: {total_params_count:,}")
348
-
349
- lora_save_path = os.path.join("lora", lora_name)
350
- os.makedirs(lora_save_path, exist_ok=True)
351
-
352
- def save_lora_checkpoint(model):
353
- if accelerator.is_main_process:
354
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
355
- from peft.utils.save_and_load import get_peft_model_state_dict
356
- lora_state_dict = get_peft_model_state_dict(model)
357
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
358
- model.peft_config["default"].save_pretrained(lora_save_path)
359
- from diffusers import StableDiffusionXLPipeline
360
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
361
-
362
- # --------------------------- Оптимизатор ---------------------------
363
- if lora_name:
364
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
365
- else:
366
- trainable_params = list(unet.parameters()) if fbp else [p for p in unet.parameters() if p.requires_grad]
367
-
368
-
369
- def create_optimizer(name, params):
370
- if name == "adam8bit":
371
- return bnb.optim.AdamW8bit(
372
- params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
373
- percentile_clipping=percentile_clipping
374
- )
375
- elif name == "adam":
376
- return torch.optim.AdamW(
377
- params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
378
- )
379
- elif name == "lion8bit":
380
- return bnb.optim.Lion8bit(
381
- params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
382
- percentile_clipping=percentile_clipping
383
- )
384
- elif name == "adafactor":
385
- from transformers import Adafactor
386
- return Adafactor(
387
- params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
388
- warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
389
- beta1=0.9, weight_decay=0.01
390
- )
391
- else:
392
- raise ValueError(f"Unknown optimizer: {name}")
393
-
394
- if fbp:
395
- optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
396
- def optimizer_hook(param):
397
- optimizer_dict[param].step()
398
- optimizer_dict[param].zero_grad(set_to_none=True)
399
- for param in trainable_params:
400
- param.register_post_accumulate_grad_hook(optimizer_hook)
401
- # FIXED: prepare fbp variant (keeps original logic)
402
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
403
- else:
404
- optimizer = create_optimizer(optimizer_type, trainable_params)
405
-
406
- # FIXED: LR schedule should be based on total_training_steps (per-process steps * epochs)
407
- def lr_schedule(step):
408
- # step is current scheduler step (0..total_training_steps)
409
- x = step / max(1, total_training_steps)
410
- warmup = warmup_percent
411
- if not use_decay:
412
- return base_learning_rate
413
- if x < warmup:
414
- return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
415
- decay_ratio = (x - warmup) / (1 - warmup)
416
- return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
417
- (1 + math.cos(math.pi * decay_ratio))
418
-
419
- # LambdaLR expects a multiplier, so divide by base_learning_rate
420
- lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
421
-
422
- num_params = sum(p.numel() for p in unet.parameters())
423
- print(f"[rank {accelerator.process_index}] total params: {num_params}")
424
- for name, param in unet.named_parameters():
425
- if torch.isnan(param).any() or torch.isinf(param).any():
426
- print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
427
-
428
- # FIXED: prepare model, optimizer, scheduler AFTER creating them and after dataloader.prepare
429
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
430
-
431
- # --------------------------- Фиксированные семплы для генерации ---------------------------
432
- fixed_samples = get_fixed_samples_by_resolution(dataset)
433
-
434
- @torch.no_grad()
435
- def generate_and_save_samples(fixed_samples_cpu, step):
436
- original_model = None
437
- try:
438
- original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
439
- vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
440
-
441
- all_generated_images = []
442
- all_captions = []
443
-
444
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
445
- width, height = size
446
- # move CPU tensors to device here (they were kept on CPU in get_fixed_samples_by_resolution)
447
- sample_latents = sample_latents.to(dtype=dtype, device=device)
448
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
449
-
450
- noise = torch.randn(
451
- sample_latents.shape,
452
- generator=gen,
453
- device=device,
454
- dtype=sample_latents.dtype
455
- )
456
- current_latents = noise.clone()
457
-
458
- if guidance_scale != 1:
459
- empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
460
- text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
461
- else:
462
- text_embeddings_batch = sample_text_embeddings
463
-
464
- timesteps = torch.linspace(0, 1, n_diffusion_steps+1, device=device, dtype=sample_latents.dtype)
465
- for i in range(0, n_diffusion_steps):
466
- t_cur = timesteps[i].unsqueeze(0)
467
- t_next = timesteps[i+1]
468
- dt = t_next - t_cur
469
- if guidance_scale != 1:
470
- latent_model_input = torch.cat((current_latents, current_latents))
471
- else:
472
- latent_model_input = current_latents
473
- t_batch = t_cur.repeat(latent_model_input.shape[0]).to(device)
474
- t_batch = (t_batch * 1000).long().view(-1)
475
- flow = original_model(latent_model_input, t_batch, text_embeddings_batch).sample
476
-
477
- if guidance_scale != 1:
478
- flow_uncond, flow_cond = flow.chunk(2)
479
- flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
480
-
481
- current_latents = current_latents + flow * dt.to(device)
482
-
483
- # Параметры нормализации
484
- latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
485
- decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
486
- decoded_fp32 = decoded.to(torch.float32)
487
- for img_idx, img_tensor in enumerate(decoded_fp32):
488
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
489
- img = img.transpose(1, 2, 0)
490
- if np.isnan(img).any():
491
- print("NaNs found, saving stopped! Step:", step)
492
- pil_img = Image.fromarray((img * 255).astype("uint8"))
493
-
494
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
495
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
496
- max_w_overall = max(255, max_w_overall)
497
- max_h_overall = max(255, max_h_overall)
498
-
499
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
500
- all_generated_images.append(padded_img)
501
-
502
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
503
- all_captions.append(caption_text)
504
-
505
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
506
- pil_img.save(sample_path, "JPEG", quality=96)
507
-
508
- if use_wandb and accelerator.is_main_process:
509
- wandb_images = [
510
- wandb.Image(img, caption=f"{all_captions[i]}")
511
- for i, img in enumerate(all_generated_images)
512
- ]
513
- wandb.log({"generated_images": wandb_images}, step=step)
514
- if use_comet_ml and accelerator.is_main_process:
515
- for i, img in enumerate(all_generated_images):
516
- comet_experiment.log_image(
517
- image_data=img,
518
- name=f"step_{step}_img_{i}",
519
- step=step,
520
- metadata={
521
- "caption": all_captions[i],
522
- "width": img.width,
523
- "height": img.height,
524
- "global_step": step
525
- }
526
- )
527
- finally:
528
- vae.to("cpu")
529
- torch.cuda.empty_cache()
530
- gc.collect()
531
-
532
- # --------------------------- Генерация сэмплов перед обучением ---------------------------
533
- if accelerator.is_main_process:
534
- if save_model:
535
- print("Генерация сэмплов до старта обучения...")
536
- generate_and_save_samples(fixed_samples, 0)
537
- accelerator.wait_for_everyone()
538
-
539
- # Модифицируем функцию сохранения модели для поддержки LoRA
540
- def save_checkpoint(unet, variant=""):
541
- if accelerator.is_main_process:
542
- if lora_name:
543
- save_lora_checkpoint(unet)
544
- else:
545
- model_to_save = accelerator.unwrap_model(unet)
546
- dest = os.path.join(checkpoints_folder, f"{project}")
547
- if variant != "":
548
- model_to_save.to(dtype=torch.float16).save_pretrained(dest, variant=variant)
549
- else:
550
- model_to_save.save_pretrained(dest)
551
-
552
-
553
- # --------------------------- Тренировочный цикл ---------------------------
554
- if accelerator.is_main_process:
555
- print(f"Total steps per GPU: {total_training_steps}")
556
-
557
- epoch_loss_points = []
558
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
559
-
560
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
561
- min_loss = 2.
562
- global_step = 0
563
-
564
- for epoch in range(0, num_epochs):
565
- # FIXED: set epoch on the dataloader's batch_sampler if available (accelerator may have wrapped it)
566
- if hasattr(dataloader, "batch_sampler") and hasattr(dataloader.batch_sampler, "set_epoch"):
567
- dataloader.batch_sampler.set_epoch(epoch)
568
- elif hasattr(batch_sampler, "set_epoch"):
569
- batch_sampler.set_epoch(epoch)
570
-
571
- accelerator.wait_for_everyone()
572
- unet.train()
573
-
574
- batch_losses = []
575
- batch_grads = []
576
-
577
- for step, (latents, embeddings) in enumerate(dataloader):
578
- with accelerator.accumulate(unet):
579
- if save_model == False and step == 5:
580
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
581
- print(f"Шаг {step}: {used_gb:.2f} GB")
582
-
583
- noise = torch.randn_like(latents, dtype=latents.dtype)
584
- t = logit_normal_samples((latents.shape[0], 1, 1, 1), mu=0.0, sigma=1.0, device=latents.device, dtype=latents.dtype)
585
- noisy_latents = (1 - t) * noise + t * latents
586
-
587
- t_for_unet = (t * 1000).long().view(-1)
588
- model_pred = unet(noisy_latents, t_for_unet, embeddings).sample
589
- target_pred = latents - noise
590
-
591
- mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
592
-
593
- # Сохраняем для локальных логов
594
- batch_losses.append(mse_loss.detach().cpu().item())
595
-
596
- # Backward
597
- accelerator.backward(mse_loss)
598
-
599
- grad_norm_val = 0.0
600
- if not fbp:
601
- if accelerator.sync_gradients:
602
- # Clip gradients and step only when gradients are synchronized (i.e. actual optimizer step)
603
- grad_norm = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
604
- grad_norm_val = float(grad_norm)
605
- optimizer.step()
606
- lr_scheduler.step()
607
- optimizer.zero_grad(set_to_none=True)
608
-
609
- # increment global_step only when we have synchronized gradients (i.e. on optimizer step)
610
- # FIXED: ensure global_step reflects optimizer updates, not micro-batches
611
- if accelerator.sync_gradients:
612
- global_step += 1
613
- progress_bar.update(1)
614
-
615
- # Логируем метрики только на главном процессе
616
- if accelerator.is_main_process:
617
- if fbp:
618
- current_lr = base_learning_rate
619
- else:
620
- current_lr = lr_scheduler.get_last_lr()[0]
621
-
622
- batch_grads.append(grad_norm_val)
623
-
624
- log_data = {
625
- "loss": mse_loss.detach().item(),
626
- "lr": current_lr,
627
- "grad": grad_norm_val,
628
- "epoch": epoch + 1,
629
- "global_step": global_step,
630
- }
631
-
632
- if use_wandb:
633
- wandb.log(log_data, step=global_step)
634
- if use_comet_ml:
635
- comet_experiment.log_metrics(log_data, step=global_step)
636
-
637
- # Генерируем сэмплы с заданным интервалом (только на главном процессе)
638
- if global_step % sample_interval == 0:
639
- generate_and_save_samples(fixed_samples, global_step)
640
-
641
- if save_model:
642
- avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
643
- if use_wandb:
644
- wandb.log({"sample_loss": avg_sample_loss})
645
- print("saving:", avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
646
- if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
647
- min_loss = avg_sample_loss
648
- save_checkpoint(unet)
649
-
650
- # Эпоха окончена — агрегируем и логируем средние значения
651
- if accelerator.is_main_process:
652
- # local averages
653
- avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
654
- avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
655
-
656
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
657
- log_data_ep = {
658
- "epoch_loss": avg_epoch_loss,
659
- "epoch_grad": avg_epoch_grad,
660
- "epoch": epoch + 1,
661
- }
662
- if use_wandb:
663
- wandb.log(log_data_ep)
664
- if use_comet_ml:
665
- comet_experiment.log_metrics(log_data_ep)
666
-
667
- # Завершение обучения - сохраняем финальную модель
668
- if accelerator.is_main_process:
669
- print("Обучение завершено! Сохраняем финальную модель...")
670
- if save_model:
671
- save_checkpoint(unet, "fp16")
672
- if use_comet_ml:
673
- comet_experiment.end()
674
- accelerator.free_memory()
675
- if torch.distributed.is_initialized():
676
- torch.distributed.destroy_process_group()
677
-
678
- print("Готово!")
train_velocity.py DELETED
@@ -1,825 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from collections import defaultdict
10
- from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL, DDPMScheduler
12
- from accelerate import Accelerator
13
- from datasets import load_from_disk
14
- from tqdm import tqdm
15
- from PIL import Image,ImageOps
16
- import wandb
17
- import random
18
- import gc
19
- from accelerate.state import DistributedType
20
- from torch.distributed import broadcast_object_list
21
- from torch.utils.checkpoint import checkpoint
22
- from diffusers.models.attention_processor import AttnProcessor2_0
23
- from datetime import datetime
24
- import bitsandbytes as bnb
25
- import torch.nn.functional as F
26
- from collections import deque
27
-
28
- # --------------------------- Параметры ---------------------------
29
- ds_path = "/workspace/sdxs3d/datasets/butterfly"
30
- project = "unet"
31
- batch_size = 16
32
- base_learning_rate = 9e-5
33
- min_learning_rate = 1e-5
34
- num_epochs = 300
35
- # samples/save per epoch
36
- sample_interval_share = 1
37
- use_wandb = True
38
- save_model = True
39
- use_decay = True
40
- fbp = False # fused backward pass
41
- optimizer_type = "adam8bit"
42
- torch_compile = False
43
- unet_gradient = True
44
- clip_sample = False #Scheduler
45
- fixed_seed = True
46
- shuffle = True
47
- torch.backends.cuda.matmul.allow_tf32 = True
48
- torch.backends.cudnn.allow_tf32 = True
49
- torch.backends.cuda.enable_mem_efficient_sdp(False)
50
- dtype = torch.float32
51
- save_barrier = 2.03 # TODO
52
- warmup_percent = 0.01
53
- dispersive_temperature=0.5
54
- dispersive_weight= 0.05
55
- percentile_clipping = 99 # 8bit optim
56
- betta2 = 0.995
57
- eps = 1e-8
58
- clip_grad_norm = 1.0
59
- steps_offset = 0 # Scheduler
60
- limit = 0
61
- checkpoints_folder = ""
62
- mixed_precision = "no" #"fp16"
63
- gradient_accumulation_steps = 1
64
- accelerator = Accelerator(
65
- mixed_precision=mixed_precision,
66
- gradient_accumulation_steps=gradient_accumulation_steps
67
- )
68
- device = accelerator.device
69
-
70
- # Параметры для диффузии
71
- n_diffusion_steps = 50
72
- samples_to_generate = 12
73
- guidance_scale = 5
74
-
75
- # Папки для сохранения результатов
76
- generated_folder = "samples"
77
- os.makedirs(generated_folder, exist_ok=True)
78
-
79
- # Настройка seed для воспроизводимости
80
- current_date = datetime.now()
81
- seed = int(current_date.strftime("%Y%m%d"))
82
- if fixed_seed:
83
- torch.manual_seed(seed)
84
- np.random.seed(seed)
85
- random.seed(seed)
86
- if torch.cuda.is_available():
87
- torch.cuda.manual_seed_all(seed)
88
-
89
- # --- Пропорции лоссов и окно медианного нормирования (КОЭФ., не значения) ---
90
- # CHANGED: добавлен huber и dispersive в пропорции, суммы = 1.0
91
- loss_ratios = {
92
- "mse": 1.0,
93
- "mae": 0.0,
94
- "huber": 0.0,
95
- "dispersive": 0.0,
96
- }
97
- median_coeff_steps = 128 # за сколько шагов считать медианные коэффициенты
98
-
99
- # --------------------------- Параметры LoRA ---------------------------
100
- lora_name = ""
101
- lora_rank = 32
102
- lora_alpha = 64
103
-
104
- print("init")
105
-
106
- # --------------------------- вспомогательные функции ---------------------------
107
- def sample_timesteps_bias(
108
- batch_size: int,
109
- progress: float, # [0..1]
110
- num_train_timesteps: int, # обычно 1000
111
- steps_offset: int = 0,
112
- device=None,
113
- mode: str = "uniform", # "beta", "uniform"
114
- ) -> torch.Tensor:
115
- """
116
- Возвращает timesteps с разным bias:
117
- - beta : как раньше (сдвиг в начало или конец в зависимости от progress)
118
- - normal : около середины (гауссовое распределение)
119
- - uniform: равномерно по всем timestep’ам
120
- """
121
-
122
- max_idx = num_train_timesteps - 1 - steps_offset
123
-
124
- if mode == "beta":
125
- alpha = 1.0 + .5 * (1.0 - progress)
126
- beta = 1.0 + .5 * progress
127
- samples = torch.distributions.Beta(alpha, beta).sample((batch_size,))
128
-
129
- elif mode == "uniform":
130
- samples = torch.rand(batch_size)
131
-
132
- else:
133
- raise ValueError(f"Unknown mode: {mode}")
134
-
135
- timesteps = steps_offset + (samples * max_idx).long().to(device)
136
- return timesteps
137
-
138
-
139
- # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ
140
- class MedianLossNormalizer:
141
- def __init__(self, desired_ratios: dict, window_steps: int):
142
- # нормируем доли на случай, если сумма != 1
143
- s = sum(desired_ratios.values())
144
- self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
145
- self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
146
- self.window = window_steps
147
-
148
- def update_and_total(self, losses: dict):
149
- """
150
- losses: dict ключ->тензор (значения лоссов)
151
- Поведение:
152
- - буферим ABS(l) только для активных (ratio>0) лоссов
153
- - coeff = ratio / median(abs(loss))
154
- - total = sum(coeff * loss) по активным лоссам
155
- CHANGED: буферим abs() — чтобы медиана была положительной и не ломала деление.
156
- """
157
- # буферим только активные лоссы
158
- for k, v in losses.items():
159
- if k in self.buffers and self.ratios.get(k, 0) > 0:
160
- self.buffers[k].append(float(v.detach().abs().cpu()))
161
-
162
- meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
163
- coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
164
-
165
- # суммируем только по активным (ratio>0)
166
- total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
167
- return total, coeffs, meds
168
-
169
- # создаём normalizer после определения loss_ratios
170
- normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
171
-
172
- class AccelerateDispersiveLoss:
173
- def __init__(self, accelerator, temperature=0.5, weight=0.5):
174
- self.accelerator = accelerator
175
- self.temperature = temperature
176
- self.weight = weight
177
- self.activations = []
178
- self.hooks = []
179
-
180
- def register_hooks(self, model, target_layer="down_blocks.0"):
181
- unwrapped_model = self.accelerator.unwrap_model(model)
182
- print("=== Поиск слоев в unwrapped модели ===")
183
- for name, module in unwrapped_model.named_modules():
184
- if target_layer in name:
185
- hook = module.register_forward_hook(self.hook_fn)
186
- self.hooks.append(hook)
187
- print(f"✅ Хук зарегистрирован на: {name}")
188
- break
189
-
190
- def hook_fn(self, module, input, output):
191
- if isinstance(output, tuple):
192
- activation = output[0]
193
- else:
194
- activation = output
195
- if len(activation.shape) > 2:
196
- activation = activation.view(activation.shape[0], -1)
197
- self.activations.append(activation.detach().clone())
198
-
199
- def compute_dispersive_loss(self):
200
- if not self.activations:
201
- return torch.tensor(0.0, requires_grad=True, device=device)
202
- local_activations = self.activations[-1].float()
203
- batch_size = local_activations.shape[0]
204
- if batch_size < 2:
205
- return torch.tensor(0.0, requires_grad=True, device=device)
206
- sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
207
- distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
208
- exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
209
- dispersive_loss = torch.log(torch.mean(exp_neg_dist))
210
- return dispersive_loss
211
-
212
- def clear_activations(self):
213
- self.activations.clear()
214
-
215
- def remove_hooks(self):
216
- for hook in self.hooks:
217
- hook.remove()
218
- self.hooks.clear()
219
-
220
-
221
- # --------------------------- Инициализация WandB ---------------------------
222
- if use_wandb and accelerator.is_main_process:
223
- wandb.init(project=project+lora_name, config={
224
- "batch_size": batch_size,
225
- "base_learning_rate": base_learning_rate,
226
- "num_epochs": num_epochs,
227
- "fbp": fbp,
228
- "optimizer_type": optimizer_type,
229
- })
230
-
231
- # Включение Flash Attention 2/SDPA
232
- torch.backends.cuda.enable_flash_sdp(True)
233
- # --------------------------- Инициализация Accelerator --------------------
234
- gen = torch.Generator(device=device)
235
- gen.manual_seed(seed)
236
-
237
- # --------------------------- Загрузка моделей ---------------------------
238
- # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
239
- #vae = AutoencoderKLWan.from_pretrained("vae", variant="fp16").to(device="cpu", dtype=torch.float16).eval()
240
- #vae = AutoencoderKLWan.from_pretrained(
241
- # "AiArtLab/simplevae", subfolder="wan16x_vae_nightly",
242
- # torch_dtype=dtype
243
- # ).to(device="cpu").eval()
244
- vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to(device).eval()
245
-
246
- shift_factor = getattr(vae.config, "shift_factor", 0.0)
247
- if shift_factor is None:
248
- shift_factor = 0.0
249
-
250
- scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
251
- if scaling_factor is None:
252
- scaling_factor = 1.0
253
-
254
- latents_mean = getattr(vae.config, "latents_mean", None)
255
- latents_std = getattr(vae.config, "latents_std", None)
256
-
257
- # DDPMScheduler с V_Prediction и Zero-SNR
258
- scheduler = DDPMScheduler(
259
- num_train_timesteps=1000,
260
- prediction_type="v_prediction",
261
- rescale_betas_zero_snr=True,
262
- clip_sample = clip_sample,
263
- steps_offset = steps_offset
264
- )
265
-
266
-
267
- class DistributedResolutionBatchSampler(Sampler):
268
- def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
269
- self.dataset = dataset
270
- self.batch_size = max(1, batch_size // num_replicas)
271
- self.num_replicas = num_replicas
272
- self.rank = rank
273
- self.shuffle = shuffle
274
- self.drop_last = drop_last
275
- self.epoch = 0
276
-
277
- try:
278
- widths = np.array(dataset["width"])
279
- heights = np.array(dataset["height"])
280
- except KeyError:
281
- widths = np.zeros(len(dataset))
282
- heights = np.zeros(len(dataset))
283
-
284
- self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
285
- self.size_groups = {}
286
- for w, h in self.size_keys:
287
- mask = (widths == w) & (heights == h)
288
- self.size_groups[(w, h)] = np.where(mask)[0]
289
-
290
- self.group_num_batches = {}
291
- total_batches = 0
292
- for size, indices in self.size_groups.items():
293
- num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
294
- self.group_num_batches[size] = num_full_batches
295
- total_batches += num_full_batches
296
-
297
- self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
298
-
299
- def __iter__(self):
300
- if torch.cuda.is_available():
301
- torch.cuda.empty_cache()
302
- all_batches = []
303
- rng = np.random.RandomState(self.epoch)
304
-
305
- for size, indices in self.size_groups.items():
306
- indices = indices.copy()
307
- if self.shuffle:
308
- rng.shuffle(indices)
309
- num_full_batches = self.group_num_batches[size]
310
- if num_full_batches == 0:
311
- continue
312
- valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
313
- batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
314
- start_idx = self.rank * self.batch_size
315
- end_idx = start_idx + self.batch_size
316
- gpu_batches = batches[:, start_idx:end_idx]
317
- all_batches.extend(gpu_batches)
318
-
319
- if self.shuffle:
320
- rng.shuffle(all_batches)
321
- accelerator.wait_for_everyone()
322
- return iter(all_batches)
323
-
324
- def __len__(self):
325
- return self.num_batches
326
-
327
-     def set_epoch(self, epoch):
-         self.epoch = epoch
-
- # Function to pick fixed samples for each resolution
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
-     size_groups = defaultdict(list)
-     try:
-         widths = dataset["width"]
-         heights = dataset["height"]
-     except KeyError:
-         widths = [0] * len(dataset)
-         heights = [0] * len(dataset)
-     for i, (w, h) in enumerate(zip(widths, heights)):
-         size = (w, h)
-         size_groups[size].append(i)
-
-     fixed_samples = {}
-     for size, indices in size_groups.items():
-         n_samples = min(samples_per_group, len(indices))
-         if len(size_groups)==1:
-             n_samples = samples_to_generate
-         if n_samples == 0:
-             continue
-         sample_indices = random.sample(indices, n_samples)
-         samples_data = [dataset[idx] for idx in sample_indices]
-         latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
-         embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
-         texts = [item["text"] for item in samples_data]
-         fixed_samples[size] = (latents, embeddings, texts)
-
-     print(f"Created {len(fixed_samples)} fixed-sample groups by resolution")
-     return fixed_samples
-
- if limit > 0:
-     dataset = load_from_disk(ds_path).select(range(limit))
- else:
-     dataset = load_from_disk(ds_path)
-
- def collate_fn_simple(batch):
-     latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
-     embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
-     return latents, embeddings
-
- batch_sampler = DistributedResolutionBatchSampler(
-     dataset=dataset,
-     batch_size=batch_size,
-     num_replicas=accelerator.num_processes,
-     rank=accelerator.process_index,
-     shuffle=shuffle
- )
-
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
- print("Total samples",len(dataloader))
- dataloader = accelerator.prepare(dataloader)
-
- start_epoch = 0
- global_step = 0
- total_training_steps = (len(dataloader) * num_epochs)
- world_size = accelerator.state.num_processes
-
- # Option to load the model from the latest checkpoint (if it exists)
- latest_checkpoint = os.path.join(checkpoints_folder, project)
- if os.path.isdir(latest_checkpoint):
-     print("Loading UNet from checkpoint:", latest_checkpoint)
-     unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
-     if torch_compile:
-         print("compiling")
-         torch.set_float32_matmul_precision('high')
-         unet = torch.compile(unet)
-         print("compiling - ok")
-     if unet_gradient:
-         unet.enable_gradient_checkpointing()
-     unet.set_use_memory_efficient_attention_xformers(False)
-     try:
-         unet.set_attn_processor(AttnProcessor2_0())
-     except Exception as e:
-         print(f"Error enabling SDPA: {e}")
-         unet.set_use_memory_efficient_attention_xformers(True)
-
-     # Create the dispersive hook only if it is needed
-     if loss_ratios.get("dispersive", 0) > 0:
-         dispersive_hook = AccelerateDispersiveLoss(
-             accelerator=accelerator,
-             temperature=dispersive_temperature,
-             weight=dispersive_weight
-         )
- else:
-     # FIX: if there is no checkpoint, stop with a clear error (better than an unexpected NameError later)
-     raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Place a UNet checkpoint at {latest_checkpoint} or specify a different path.")
-
- if lora_name:
-     print(f"--- Setting up LoRA via PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
-     from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-     from peft.tuners.lora import LoraModel
-     import os
-     unet.requires_grad_(False)
-     print("Base UNet parameters frozen.")
-
-     lora_config = LoraConfig(
-         r=lora_rank,
-         lora_alpha=lora_alpha,
-         target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-     )
-     unet.add_adapter(lora_config)
-
-     from peft import get_peft_model
-     peft_unet = get_peft_model(unet, lora_config)
-     params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
-
-     if accelerator.is_main_process:
-         lora_params_count = sum(p.numel() for p in params_to_optimize)
-         total_params_count = sum(p.numel() for p in unet.parameters())
-         print(f"Trainable parameters (LoRA): {lora_params_count:,}")
-         print(f"Total UNet parameters: {total_params_count:,}")
-
-     lora_save_path = os.path.join("lora", lora_name)
-     os.makedirs(lora_save_path, exist_ok=True)
-
-     def save_lora_checkpoint(model):
-         if accelerator.is_main_process:
-             print(f"Saving LoRA adapters to {lora_save_path}")
-             from peft.utils.save_and_load import get_peft_model_state_dict
-             lora_state_dict = get_peft_model_state_dict(model)
-             torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
-             model.peft_config["default"].save_pretrained(lora_save_path)
-             from diffusers import StableDiffusionXLPipeline
-             StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
-
- # --------------------------- Optimizer ---------------------------
- if lora_name:
-     trainable_params = [p for p in unet.parameters() if p.requires_grad]
- else:
-     if fbp:
-         trainable_params = list(unet.parameters())
-
- def create_optimizer(name, params):
-     if name == "adam8bit":
-         return bnb.optim.AdamW8bit(
-             params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
-             percentile_clipping=percentile_clipping
-         )
-     elif name == "adam":
-         return torch.optim.AdamW(
-             params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
-         )
-     elif name == "lion8bit":
-         return bnb.optim.Lion8bit(
-             params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
-             percentile_clipping=percentile_clipping
-         )
-     elif name == "adafactor":
-         from transformers import Adafactor
-         return Adafactor(
-             params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
-             warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
-             beta1=0.9, weight_decay=0.01
-         )
-     else:
-         raise ValueError(f"Unknown optimizer: {name}")
-
- if fbp:
-     optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
-     def optimizer_hook(param):
-         optimizer_dict[param].step()
-         optimizer_dict[param].zero_grad(set_to_none=True)
-     for param in trainable_params:
-         param.register_post_accumulate_grad_hook(optimizer_hook)
-     unet, optimizer = accelerator.prepare(unet, optimizer_dict)
- else:
-     optimizer = create_optimizer(optimizer_type, unet.parameters())
-     def lr_schedule(step):
-         x = step / (total_training_steps * world_size)
-         warmup = warmup_percent
-         if not use_decay:
-             return base_learning_rate
-         if x < warmup:
-             return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
-         decay_ratio = (x - warmup) / (1 - warmup)
-         return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
-             (1 + math.cos(math.pi * decay_ratio))
-     lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
-
-     num_params = sum(p.numel() for p in unet.parameters())
-     print(f"[rank {accelerator.process_index}] total params: {num_params}")
-     for name, param in unet.named_parameters():
-         if torch.isnan(param).any() or torch.isinf(param).any():
-             print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
-     unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
-
- # Register hooks AFTER prepare
- if loss_ratios.get("dispersive", 0) > 0:
-     dispersive_hook.register_hooks(unet, "down_blocks.2")
-
- # --------------------------- Fixed samples for generation ---------------------------
- fixed_samples = get_fixed_samples_by_resolution(dataset)
-
- @torch.compiler.disable()
- @torch.no_grad()
- def generate_and_save_samples(fixed_samples_cpu, step):
-     original_model = None
-     try:
-         original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
-         vae.to(device=device).eval() # temporarily move the VAE to the GPU for decoding
-
-         scheduler.set_timesteps(n_diffusion_steps)
-
-         all_generated_images = []
-         all_captions = []
-
-         for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
-             width, height = size
-             sample_latents = sample_latents.to(dtype=dtype, device=device)
-             sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
-
-             noise = torch.randn(
-                 sample_latents.shape,
-                 generator=gen,
-                 device=device,
-                 dtype=sample_latents.dtype
-             )
-             current_latents = noise.clone()
-
-             if guidance_scale > 0:
-                 empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
-                 text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
-             else:
-                 text_embeddings_batch = sample_text_embeddings
-
-             for t in scheduler.timesteps:
-                 t_batch = t.repeat(current_latents.shape[0]).to(device)
-                 if guidance_scale > 0:
-                     latent_model_input = torch.cat([current_latents] * 2)
-                 else:
-                     latent_model_input = current_latents
-
-                 latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
-                 noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
-
-                 if guidance_scale > 0:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                 current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
-
-             #print(current_latents.ndim, current_latents.shape)
-             #if current_latents.ndim == 4:
-             #    current_latents = current_latents.unsqueeze(2)
-             # Latent in the shape [B, C, T, H, W]
-             #print(current_latents.ndim, current_latents.shape)
-
-             # Normalization parameters
-             latent_for_vae = current_latents.detach() * scaling_factor + shift_factor
-
-             if latents_mean!=None and latents_std!=None:
-                 latent_for_vae = latent_for_vae * torch.tensor(latents_std, device=device, dtype=dtype).view(1, -1, 1, 1, 1) + torch.tensor(latents_mean, device=device, dtype=dtype).view(1, -1, 1, 1, 1)
-
-             decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
-             #decoded = decoded[:, :, 0, :, :] # [3, H, W]
-             #print(decoded.ndim, decoded.shape)
-
-             decoded_fp32 = decoded.to(torch.float32)
-             for img_idx, img_tensor in enumerate(decoded_fp32):
-
-                 # Shape: [3, H, W] -> convert to [H, W, 3]
-                 img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
-                 img = img.transpose(1, 2, 0) # from [3, H, W] to [H, W, 3]
-
-                 #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
-                 if np.isnan(img).any():
-                     print("NaNs found, saving stopped! Step:", step)
-                 pil_img = Image.fromarray((img * 255).astype("uint8"))
-
-                 max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
-                 max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
-                 max_w_overall = max(255, max_w_overall)
-                 max_h_overall = max(255, max_h_overall)
-
-                 padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
-                 all_generated_images.append(padded_img)
-
-                 caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
-                 all_captions.append(caption_text)
-
-                 sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
-                 pil_img.save(sample_path, "JPEG", quality=96)
-
-         if use_wandb and accelerator.is_main_process:
-             wandb_images = [
-                 wandb.Image(img, caption=f"{all_captions[i]}")
-                 for i, img in enumerate(all_generated_images)
-             ]
-             wandb.log({"generated_images": wandb_images, "global_step": step})
-     finally:
-         # move the VAE back to the CPU (as it was in your code)
-         vae.to("cpu")
-         for var in list(locals().keys()):
-             if isinstance(locals()[var], torch.Tensor):
-                 del locals()[var]
-         torch.cuda.empty_cache()
-         gc.collect()
-
- # --------------------------- Sample generation before training ---------------------------
- if accelerator.is_main_process:
-     if save_model:
-         print("Generating samples before training starts...")
-         generate_and_save_samples(fixed_samples,0)
- accelerator.wait_for_everyone()
-
- # Modified model-saving function with LoRA support
- def save_checkpoint(unet,variant=""):
-     if accelerator.is_main_process:
-         if lora_name:
-             save_lora_checkpoint(unet)
-         else:
-             if variant!="":
-                 accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
-             else:
-                 accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
-         unet = unet.to(dtype=dtype)
-
- def batch_pred_original_from_step(model_outputs, timesteps_tensor, noisy_latents, scheduler):
-     device = noisy_latents.device
-     dtype = noisy_latents.dtype
-
-     available_ts = scheduler.timesteps
-     if not isinstance(available_ts, torch.Tensor):
-         available_ts = torch.tensor(available_ts, device="cpu")
-     else:
-         available_ts = available_ts.cpu()
-
-     B = model_outputs.shape[0]
-     preds = []
-     for i in range(B):
-         t_i = int(timesteps_tensor[i].item())
-         diffs = torch.abs(available_ts - t_i)
-         idx = int(torch.argmin(diffs).item())
-         t_for_step = int(available_ts[idx].item())
-         model_out_i = model_outputs[i:i+1]
-         noisy_latent_i = noisy_latents[i:i+1]
-         step_out = scheduler.step(model_out_i, t_for_step, noisy_latent_i)
-         preds.append(step_out.pred_original_sample)
-
-     return torch.cat(preds, dim=0).to(device=device, dtype=dtype)
-
- # --------------------------- Training loop ---------------------------
- if accelerator.is_main_process:
-     print(f"Total steps per GPU: {total_training_steps}")
-
- epoch_loss_points = []
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
-
- steps_per_epoch = len(dataloader)
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
- min_loss = 1.
-
- for epoch in range(start_epoch, start_epoch + num_epochs):
-     batch_losses = []
-     batch_tlosses = []
-     batch_grads = []
-     batch_sampler.set_epoch(epoch)
-     accelerator.wait_for_everyone()
-     unet.train()
-     print("epoch:",epoch)
-     for step, (latents, embeddings) in enumerate(dataloader):
-         with accelerator.accumulate(unet):
-             if save_model == False and step == 5 :
-                 used_gb = torch.cuda.max_memory_allocated() / 1024**3
-                 print(f"Step {step}: {used_gb:.2f} GB")
-
-             noise = torch.randn_like(latents, dtype=latents.dtype)
-
-             progress = global_step / max(1, total_training_steps - 1)
-             timesteps = sample_timesteps_bias(
-                 batch_size=latents.shape[0],
-                 progress=progress,
-                 num_train_timesteps=scheduler.config.num_train_timesteps,
-                 steps_offset=steps_offset,
-                 device=device
-             )
-
-             noisy_latents = scheduler.add_noise(latents, noise, timesteps)
-
-             if loss_ratios.get("dispersive", 0) > 0:
-                 dispersive_hook.clear_activations()
-
-             #print(latents.shape,embeddings.shape)
-             model_pred = unet(noisy_latents, timesteps, embeddings).sample
-             target_pred = scheduler.get_velocity(latents, noise, timesteps)
-
-             # === Losses ===
-             losses_dict = {}
-
-             mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
-             losses_dict["mse"] = mse_loss
-             losses_dict["mae"] = F.l1_loss(model_pred.float(), target_pred.float())
-
-             # CHANGED: Huber (smooth_l1) loss added
-             losses_dict["huber"] = F.smooth_l1_loss(model_pred.float(), target_pred.float())
-
-             # === Dispersive loss ===
-             if loss_ratios.get("dispersive", 0) > 0:
-                 disp_raw = dispersive_hook.compute_dispersive_loss().to(device) # can be negative
-                 losses_dict["dispersive"] = dispersive_hook.weight * disp_raw
-             else:
-                 losses_dict["dispersive"] = torch.tensor(0.0, device=device)
-
-             # === Normalize all losses ===
-             abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
-             total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
-
-             # Keep for logging (MSE is tracked separately as the reference metric)
-             batch_losses.append(mse_loss.detach().item())
-
-             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
-                 accelerator.wait_for_everyone()
-
-             # Backward
-             accelerator.backward(total_loss)
-
-             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
-                 accelerator.wait_for_everyone()
-
-             grad = 0.0
-             if not fbp:
-                 if accelerator.sync_gradients:
-                     with torch.amp.autocast('cuda', enabled=False):
-                         grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
-                         grad = float(grad_val)
-                 optimizer.step()
-                 lr_scheduler.step()
-                 optimizer.zero_grad(set_to_none=True)
-
-             global_step += 1
-             progress_bar.update(1)
-
-             # Log metrics
-             if accelerator.is_main_process:
-                 if fbp:
-                     current_lr = base_learning_rate
-                 else:
-                     current_lr = lr_scheduler.get_last_lr()[0]
-                 batch_tlosses.append(total_loss.detach().item())
-                 batch_grads.append(grad)
-
-                 # Log only active losses (ratio > 0)
-                 active_keys = [k for k, v in loss_ratios.items() if v > 0]
-                 log_data = {}
-                 for k in active_keys:
-                     v = losses_dict.get(k, None)
-                     if v is None:
-                         continue
-                     log_data[f"loss/{k}"] = (v.item() if isinstance(v, torch.Tensor) else float(v))
-
-                 log_data["loss/total"] = float(total_loss.item())
-                 log_data["loss/lr"] = current_lr
-                 for k, c in coeffs.items():
-                     log_data[f"coeff/{k}"] = float(c)
-                 if use_wandb and accelerator.sync_gradients:
-                     wandb.log(log_data, step=global_step)
-
-                 # Generate samples at the given interval
-                 if global_step % sample_interval == 0:
-                     generate_and_save_samples(fixed_samples,global_step)
-                     last_n = sample_interval
-                     avg_loss = float(np.mean(batch_losses[-last_n:])) if len(batch_losses) > 0 else 0.0
-                     avg_tloss = float(np.mean(batch_tlosses[-last_n:])) if len(batch_tlosses) > 0 else 0.0
-                     avg_grad = float(np.mean(batch_grads[-last_n:])) if len(batch_grads) > 0 else 0.0
-                     print(f"Epoch {epoch}, step {global_step}, average loss: {avg_loss:.6f}, grad: {avg_grad:.6f}")
-
-                     if save_model:
-                         print("saving:",avg_loss < min_loss*save_barrier)
-                         if avg_loss < min_loss*save_barrier:
-                             min_loss = avg_loss
-                             save_checkpoint(unet)
-                     if use_wandb:
-                         avg_data = {}
-                         avg_data["avg/loss"] = avg_loss
-                         avg_data["avg/tloss"] = avg_tloss
-                         avg_data["avg/grad"] = avg_grad
-                         wandb.log(avg_data, step=global_step)
-
-     if accelerator.is_main_process:
-         avg_epoch_loss = np.mean(batch_losses) if len(batch_losses)>0 else 0.0
-         print(f"\nEpoch {epoch} finished. Average loss: {avg_epoch_loss:.6f}")
-         if use_wandb:
-             wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
-
- # End of training - save the final model
- if loss_ratios.get("dispersive", 0) > 0:
-     dispersive_hook.remove_hooks()
- if accelerator.is_main_process:
-     print("Training finished! Saving the final model...")
-     if save_model:
-         save_checkpoint(unet,"fp16")
-     accelerator.free_memory()
-     if torch.distributed.is_initialized():
-         torch.distributed.destroy_process_group()
-
-     print("Done!")
 
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:34d3888f2e6c927fe69860a6e0ace71bc3e12f5534b9da3e4cbc73a64ff505dc
+ oid sha256:16fb28e82ad5c238b034ee615da6661457dd4151175085b46a3e48222c0d786a
  size 6184944280