recoilme commited on
Commit
841522d
·
1 Parent(s): 8d7557f
2b/config.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:bca9f7e0281e454233618bc1bae900d4d67e2f7706d604cbe026d754afc9b317
3
- size 1777
 
 
 
 
2b/diffusion_pytorch_model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a92d7e036f0127b0be424dafe24e1aef3aae84fb411f61a525efbb41f712a5e7
3
- size 7993399544
 
 
 
 
micro/config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:18891ea8f81c705422b12d7575493f5abb949d60f2814a989a436fde65e84c82
3
  size 1873
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c41005eb2fa8b69648ec72d7cbcae0a4a6c3a63e8df9f79a2a417b3c585b210
3
  size 1873
micro/diffusion_pytorch_model.fp16.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a60767940a359caea48ee976ca37edf17916c05b4924238690ab10beae663833
3
- size 1964968816
 
 
 
 
micro/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:39f0ed5b02f53bf873b8fd0aaebf67cf589300cb2e0b3a5aecc62a0212a6e99d
3
  size 3929714960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb07de5296dcb2ff173cf5c6baad037721b2212f96b990dbc93491005dd3092d
3
  size 3929714960
result_grid.jpg CHANGED

Git LFS Details

  • SHA256: ccc5d980d4d6b07bb0cc39a5ea32ccc4227dd1e21f8369c2da645d6047a544ba
  • Pointer size: 132 Bytes
  • Size of remote file: 7.13 MB

Git LFS Details

  • SHA256: c880e830f36717315353f5b267de147af834018cbc0d26fea2f4399ad0e2488c
  • Pointer size: 132 Bytes
  • Size of remote file: 7.06 MB
samples/2b_192x384_0.jpg DELETED

Git LFS Details

  • SHA256: 0df3ab0f90ebfe052193eeeaf9b894b56380350d3efbf238c4e7640dcfab365b
  • Pointer size: 130 Bytes
  • Size of remote file: 36.1 kB
samples/2b_256x384_0.jpg DELETED

Git LFS Details

  • SHA256: ad5c7326d6295ff5feaac9e9220333978fb6ce0a976a58b7a288f182e4736094
  • Pointer size: 130 Bytes
  • Size of remote file: 64.8 kB
samples/2b_320x384_0.jpg DELETED

Git LFS Details

  • SHA256: ac216aac054b0aa2ac4aed7de7a2cbb367f9ae709c8c8100de4fbffad331df26
  • Pointer size: 130 Bytes
  • Size of remote file: 49.1 kB
samples/2b_384x192_0.jpg DELETED

Git LFS Details

  • SHA256: 971eb15ab9ecbc83d6e16b71b344d41c6c86d837d341b701501abc99e6c42a2b
  • Pointer size: 130 Bytes
  • Size of remote file: 35.9 kB
samples/2b_384x256_0.jpg DELETED

Git LFS Details

  • SHA256: ee453de08295c9f26dc9a63d6e887ac1bbbd8e26645e77815df85572ba112b20
  • Pointer size: 130 Bytes
  • Size of remote file: 38.9 kB
samples/2b_384x320_0.jpg DELETED

Git LFS Details

  • SHA256: 8ff81dda2dd78c1c96fd339cc719628996e81f5e3c9e817c0f26ac27fb2b853a
  • Pointer size: 130 Bytes
  • Size of remote file: 49.7 kB
samples/2b_384x384_0.jpg DELETED

Git LFS Details

  • SHA256: cbe5d2cf0ac98d65b0ad627aa3df89ae6893610e09b2693529d79877dcf68f35
  • Pointer size: 130 Bytes
  • Size of remote file: 68.4 kB
samples/micro_192x384_0.jpg CHANGED

Git LFS Details

  • SHA256: b38bcec72835a606d284e1d698068281a3f248e10b6017fd5c8df92dd0e1e968
  • Pointer size: 130 Bytes
  • Size of remote file: 29.4 kB

Git LFS Details

  • SHA256: 50fbe4a7140c5f7614eb01d228ef85c29686974f6fe407071fef003484e73114
  • Pointer size: 130 Bytes
  • Size of remote file: 26.5 kB
samples/micro_256x384_0.jpg CHANGED

Git LFS Details

  • SHA256: e159c96de747fb460079861ac277000705951b935d05aab0606541e04a02365c
  • Pointer size: 130 Bytes
  • Size of remote file: 51.6 kB

Git LFS Details

  • SHA256: 08b6e8cceeb5794010e07613692fc516aee3d72b57e6db60f3fe3c0bb9a6136e
  • Pointer size: 130 Bytes
  • Size of remote file: 38.2 kB
samples/micro_320x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 8dee16eca87ff2300a78a7241fd6ced8479766b5175ae9291a99f631bf587a9b
  • Pointer size: 130 Bytes
  • Size of remote file: 52.7 kB

Git LFS Details

  • SHA256: 47d10e86f8c36ed8fab410a105e9d616ef28d04fca193e2b002473bc6702e93a
  • Pointer size: 130 Bytes
  • Size of remote file: 71.1 kB
samples/micro_384x192_0.jpg CHANGED

Git LFS Details

  • SHA256: f566e1dc65b947bdefe67948642c4f4b47734606a73dfb40dbd05afbb74dde2d
  • Pointer size: 130 Bytes
  • Size of remote file: 41.7 kB

Git LFS Details

  • SHA256: d09ec9abac71e25776012822cba0050b7c55b411e0812e77e737ea47b7f2b86b
  • Pointer size: 130 Bytes
  • Size of remote file: 58.2 kB
samples/micro_384x256_0.jpg CHANGED

Git LFS Details

  • SHA256: e85ba175158161581023108b987a469eb96a2f3f78af7f16c16a6042322829ce
  • Pointer size: 130 Bytes
  • Size of remote file: 46.8 kB

Git LFS Details

  • SHA256: 5e14aef2d467c021e3eef73f461d1e1b46cb8bac25956b3eb2d2c0ff96b69c40
  • Pointer size: 130 Bytes
  • Size of remote file: 53.3 kB
samples/micro_384x320_0.jpg CHANGED

Git LFS Details

  • SHA256: c8e4fbdf8667ae9cb8b9fe9ddc03cb33d9c7e6e86ae585fb6bcc775d4cf477c5
  • Pointer size: 130 Bytes
  • Size of remote file: 50.2 kB

Git LFS Details

  • SHA256: 5d43424249007d387e5bdde038a19388a6c03f06ec6b831ee4dbfbaf8fa202eb
  • Pointer size: 130 Bytes
  • Size of remote file: 37.1 kB
samples/micro_384x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 409e48d22021d422fddd528a3231eba7280e3d02fdf098a13eebd50c62d25025
  • Pointer size: 130 Bytes
  • Size of remote file: 83.6 kB

Git LFS Details

  • SHA256: 9ddf681a45ba4aa572960f549a3fefdcbde65bc9d6b236e7cd893163528a7e51
  • Pointer size: 130 Bytes
  • Size of remote file: 70.6 kB
samples/sdxl_192x384_0.jpg DELETED

Git LFS Details

  • SHA256: cf70c7894186a9b272dc2e08f62939af276d300c1b15450dc58697e9bdadbaef
  • Pointer size: 130 Bytes
  • Size of remote file: 42.4 kB
samples/sdxl_256x384_0.jpg DELETED

Git LFS Details

  • SHA256: ca5f4d51c9331f3e35c99166e73de84e5da0eb439f9efa9c58973af741ff35a5
  • Pointer size: 130 Bytes
  • Size of remote file: 31 kB
samples/sdxl_320x384_0.jpg DELETED

Git LFS Details

  • SHA256: 4e9edbc5cf2600f78318cef708d620575f3b5ba95a5d72605981db4c4fd9127b
  • Pointer size: 130 Bytes
  • Size of remote file: 62.7 kB
samples/sdxl_384x192_0.jpg DELETED

Git LFS Details

  • SHA256: 450d1a98b3cfeabf476db781c8850d0c40b246c54b5cd91d9451f984ce5f73f5
  • Pointer size: 130 Bytes
  • Size of remote file: 27.2 kB
samples/sdxl_384x256_0.jpg DELETED

Git LFS Details

  • SHA256: c884890c77b605f2106562a71cdf46212d0a5e2abff8542e659dd04a459ca6cb
  • Pointer size: 130 Bytes
  • Size of remote file: 28.3 kB
samples/sdxl_384x320_0.jpg DELETED

Git LFS Details

  • SHA256: b7ee68f3686c418f9f736e4a296302526cdd06cdb570c9229facb4a79f1df899
  • Pointer size: 130 Bytes
  • Size of remote file: 57.3 kB
sdxl/config.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1fc6f3f339d56f6e5fc44b2b62cd0ceb46616137961a105f2579298631328f21
3
- size 1768
 
 
 
 
sdxl/diffusion_pytorch_model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f14e079e3164d58155294204b475f1782d48e4364a1487000a28b0a16e6a281f
3
- size 3944400016
 
 
 
 
test.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6a2ebd726ffc5d250ecf6e9807f925637165df36dcd29e62342bc4378d6e13d0
3
- size 5706983
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0e7fea8d086cfaf3a2a88eee3c2a63239db6ea6eec6ebb2e0d46d4d3ca81597
3
+ size 5995904
train-Copy1.py DELETED
@@ -1,1008 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from collections import defaultdict
10
- from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
12
- from accelerate import Accelerator
13
- from datasets import load_from_disk
14
- from tqdm import tqdm
15
- from PIL import Image,ImageOps
16
- import wandb
17
- import random
18
- import gc
19
- from accelerate.state import DistributedType
20
- from torch.distributed import broadcast_object_list
21
- from torch.utils.checkpoint import checkpoint
22
- from diffusers.models.attention_processor import AttnProcessor2_0
23
- from datetime import datetime
24
- import bitsandbytes as bnb
25
- import torch.nn.functional as F
26
-
27
- # --------------------------- Параметры ---------------------------
28
- ds_path = "datasets/576"
29
- project = "unet"
30
- batch_size = 50
31
- base_learning_rate = 9e-6
32
- min_learning_rate = 8e-6
33
- num_epochs = 5
34
- # samples/save per epoch
35
- sample_interval_share = 5
36
- use_wandb = True
37
- save_model = True
38
- use_decay = True
39
- fbp = False # fused backward pass
40
- optimizer_type = "adam8bit"
41
- torch_compile = False
42
- unet_gradient = True
43
- clip_sample = False #Scheduler
44
- fixed_seed = False
45
- shuffle = True
46
- dispersive_loss = True
47
- torch.backends.cuda.matmul.allow_tf32 = True
48
- torch.backends.cudnn.allow_tf32 = True
49
- torch.backends.cuda.enable_mem_efficient_sdp(False)
50
- dtype = torch.float32
51
- save_barrier = 1.03
52
- dispersive_temperature=0.5
53
- dispersive_weight=0.05
54
- percentile_clipping = 90 # 8bit optim
55
- steps_offset = 1 # Scheduler
56
- limit = 0
57
- checkpoints_folder = ""
58
- mixed_precision = "fp16"
59
- accelerator = Accelerator(mixed_precision=mixed_precision)
60
- device = accelerator.device
61
-
62
- # Параметры для диффузии
63
- n_diffusion_steps = 50
64
- samples_to_generate = 12
65
- guidance_scale = 5
66
-
67
- # Папки для сохранения результатов
68
- generated_folder = "samples"
69
- os.makedirs(generated_folder, exist_ok=True)
70
-
71
- # Настройка seed для воспроизводимости
72
- current_date = datetime.now()
73
- seed = int(current_date.strftime("%Y%m%d"))
74
- if fixed_seed:
75
- torch.manual_seed(seed)
76
- np.random.seed(seed)
77
- random.seed(seed)
78
- if torch.cuda.is_available():
79
- torch.cuda.manual_seed_all(seed)
80
-
81
- # --------------------------- Параметры LoRA ---------------------------
82
- # pip install peft
83
- lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
84
- lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель)
85
- lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб
86
-
87
- print("init")
88
-
89
- class AccelerateDispersiveLoss:
90
- def __init__(self, accelerator, temperature=0.5, weight=0.5):
91
- self.accelerator = accelerator
92
- self.temperature = temperature
93
- self.weight = weight
94
- self.activations = []
95
- self.hooks = []
96
-
97
- def register_hooks(self, model, target_layer="down_blocks.0"):
98
- unwrapped_model = self.accelerator.unwrap_model(model)
99
- print("=== Поиск слоев в unwrapped модели ===")
100
- for name, module in unwrapped_model.named_modules():
101
- if target_layer in name:
102
- hook = module.register_forward_hook(self.hook_fn)
103
- self.hooks.append(hook)
104
- print(f"✅ Хук зарегистрирован на: {name}")
105
- break
106
-
107
- def hook_fn(self, module, input, output):
108
-
109
- if isinstance(output, tuple):
110
- activation = output[0]
111
- else:
112
- activation = output
113
-
114
- if len(activation.shape) > 2:
115
- activation = activation.view(activation.shape[0], -1)
116
-
117
- self.activations.append(activation.detach())
118
-
119
- def compute_dispersive_loss(self):
120
- if not self.activations:
121
- return torch.tensor(0.0, requires_grad=True)
122
-
123
- local_activations = self.activations[-1].float()
124
-
125
- batch_size = local_activations.shape[0]
126
- if batch_size < 2:
127
- return torch.tensor(0.0, requires_grad=True)
128
-
129
- # Нормализация и вычисление loss
130
- sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
131
- distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
132
- exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
133
- dispersive_loss = torch.log(torch.mean(exp_neg_dist))
134
-
135
- # ВАЖНО: он отриц и должен падать
136
- return dispersive_loss
137
-
138
- def compute_dispersive_loss2(self):
139
- # Если нет активаций, возвращаем 0
140
- if not self.activations:
141
- return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
142
-
143
- # Работаем только с локальными активациями главного процесса
144
- activations = self.activations[-1].float()
145
-
146
- batch_size = activations.shape[0]
147
- if batch_size < 2:
148
- return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
149
-
150
- # Нормализация
151
- norm = torch.norm(activations, dim=1, keepdim=True).clamp(min=1e-12)
152
- sf = activations / norm
153
-
154
- # Вычисляем расстояния
155
- distance = torch.nn.functional.pdist(sf, p=2)
156
- distance = distance.clamp(min=1e-12)
157
- distance_squared = distance ** 2
158
-
159
- # Вычисляем loss с клиппингом для стабильности
160
- exp_neg_dist = torch.exp((-distance_squared / self.temperature).clamp(min=-20, max=20))
161
- exp_neg_dist = exp_neg_dist + 1e-12
162
-
163
- mean_exp = torch.mean(exp_neg_dist)
164
- dispersive_loss = torch.log(mean_exp.clamp(min=1e-12))
165
-
166
- return dispersive_loss
167
-
168
- def clear_activations(self):
169
- self.activations.clear()
170
-
171
- def remove_hooks(self):
172
- for hook in self.hooks:
173
- hook.remove()
174
- self.hooks.clear()
175
-
176
- class AccelerateDispersiveLoss2:
177
- def __init__(self, accelerator, temperature=0.5, weight=0.5):
178
- self.accelerator = accelerator
179
- self.temperature = temperature
180
- self.weight = weight
181
- self.activations = []
182
- self.hooks = []
183
-
184
- def register_hooks(self, model, target_layer="down_blocks.0"):
185
- # Получаем "чистую" модель без DDP wrapper'а
186
- unwrapped_model = self.accelerator.unwrap_model(model)
187
-
188
- print("=== Поиск слоев в unwrapped модели ===")
189
- for name, module in unwrapped_model.named_modules():
190
- if target_layer in name:
191
- hook = module.register_forward_hook(self.hook_fn)
192
- self.hooks.append(hook)
193
- print(f"✅ Хук зарегистрирован на: {name}")
194
- break
195
-
196
- def hook_fn(self, module, input, output):
197
- if isinstance(output, tuple):
198
- activation = output[0]
199
- else:
200
- activation = output
201
-
202
- if len(activation.shape) > 2:
203
- activation = activation.view(activation.shape[0], -1)
204
-
205
- self.activations.append(activation.detach())
206
-
207
- def compute_dispersive_loss_fix(self):
208
- if not self.activations:
209
- return torch.tensor(0.0, requires_grad=True)
210
-
211
- local_activations = self.activations[-1]
212
-
213
- # Собираем активации со всех GPU
214
- if self.accelerator.num_processes > 1:
215
- gathered_activations = self.accelerator.gather(local_activations)
216
- else:
217
- gathered_activations = local_activations
218
-
219
- batch_size = gathered_activations.shape[0]
220
- if batch_size < 2:
221
- return torch.tensor(0.0, requires_grad=True)
222
-
223
- # Переводим в float32 для стабильности
224
- gathered_activations = gathered_activations.float()
225
-
226
- # Нормализация с eps для стабильности
227
- norm = torch.norm(gathered_activations, dim=1, keepdim=True).clamp(min=1e-12)
228
- sf = gathered_activations / norm
229
-
230
- # Вычисляем расстояния
231
- distance = torch.nn.functional.pdist(sf, p=2)
232
- distance = distance.clamp(min=1e-12) # избегаем слишком маленьких значений
233
- distance_squared = distance ** 2
234
-
235
- # Экспонента с клиппингом
236
- exp_neg_dist = torch.exp((-distance_squared / self.temperature).clamp(min=-20, max=20))
237
- exp_neg_dist = exp_neg_dist + 1e-12 # избегаем нулей
238
-
239
- # Среднее и лог
240
- mean_exp = torch.mean(exp_neg_dist)
241
- dispersive_loss = torch.log(mean_exp.clamp(min=1e-12))
242
-
243
- return dispersive_loss
244
-
245
- def compute_dispersive_loss(self):
246
- if not self.activations:
247
- return torch.tensor(0.0, requires_grad=True)
248
-
249
- local_activations = self.activations[-1].float()
250
-
251
- # Собираем активации со всех GPU
252
- if self.accelerator.num_processes > 1:
253
- gathered_activations = self.accelerator.gather(local_activations)
254
- else:
255
- gathered_activations = local_activations
256
-
257
- batch_size = gathered_activations.shape[0]
258
- if batch_size < 2:
259
- return torch.tensor(0.0, requires_grad=True)
260
-
261
- # Нормализация и вычисление loss
262
- sf = gathered_activations / torch.norm(gathered_activations, dim=1, keepdim=True)
263
- sf = sf.float()
264
- distance = torch.nn.functional.pdist(sf, p=2) ** 2
265
- exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
266
- dispersive_loss = torch.log(torch.mean(exp_neg_dist))
267
-
268
- # ВАЖНО: он отриц и должен падать
269
- return dispersive_loss
270
-
271
-
272
- def compute_dispersive_loss_single(self):
273
- if not self.activations:
274
- return torch.tensor(0.0, requires_grad=True)
275
-
276
- local_activations = self.activations[-1] # Активации с текущего GPU
277
-
278
- # Собираем активации со всех GPU
279
- if self.accelerator.num_processes > 1:
280
- # Используем accelerate для сбора
281
- gathered_activations = self.accelerator.gather(local_activations)
282
- else:
283
- gathered_activations = local_activations
284
-
285
- # На главном процессе вычисляем loss
286
- if self.accelerator.is_main_process:
287
- batch_size = gathered_activations.shape[0]
288
- if batch_size < 2:
289
- return torch.tensor(0.0, requires_grad=True)
290
-
291
- # Нормализация и вычисление loss
292
- sf = gathered_activations / torch.norm(gathered_activations, dim=1, keepdim=True)
293
- distance = torch.nn.functional.pdist(sf, p=2) ** 2
294
- exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
295
- dispersive_loss = torch.log(torch.mean(exp_neg_dist))
296
-
297
- return dispersive_loss
298
- else:
299
- # На не-главных процессах возвращаем 0
300
- return torch.tensor(0.0, requires_grad=True)
301
-
302
- def clear_activations(self):
303
- self.activations.clear()
304
-
305
- def remove_hooks(self):
306
- for hook in self.hooks:
307
- hook.remove()
308
- self.hooks.clear()
309
-
310
-
311
- # --------------------------- Инициализация WandB ---------------------------
312
- if use_wandb and accelerator.is_main_process:
313
- wandb.init(project=project+lora_name, config={
314
- "batch_size": batch_size,
315
- "base_learning_rate": base_learning_rate,
316
- "num_epochs": num_epochs,
317
- "fbp": fbp,
318
- "optimizer_type": optimizer_type,
319
- })
320
-
321
- # Включение Flash Attention 2/SDPA
322
- torch.backends.cuda.enable_flash_sdp(True)
323
- # --------------------------- Инициализация Accelerator --------------------
324
- gen = torch.Generator(device=device)
325
- gen.manual_seed(seed)
326
-
327
- # --------------------------- Загрузка моделей ---------------------------
328
- # VAE загружается на CPU для экономии GPU-памяти
329
- vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to("cpu").eval()
330
-
331
- # DDPMScheduler с V_Prediction и Zero-SNR
332
- scheduler = DDPMScheduler(
333
- num_train_timesteps=1000, # Полный график шагов для обучения
334
- prediction_type="v_prediction", # V-Prediction
335
- rescale_betas_zero_snr=True, # Включение Zero-SNR
336
- clip_sample = clip_sample,
337
- steps_offset = steps_offset
338
- )
339
-
340
-
341
- class DistributedResolutionBatchSampler(Sampler):
342
- def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
343
- self.dataset = dataset
344
- self.batch_size = max(1, batch_size // num_replicas)
345
- self.num_replicas = num_replicas
346
- self.rank = rank
347
- self.shuffle = shuffle
348
- self.drop_last = drop_last
349
- self.epoch = 0
350
-
351
- # Используем numpy для ускорения
352
- try:
353
- widths = np.array(dataset["width"])
354
- heights = np.array(dataset["height"])
355
- except KeyError:
356
- widths = np.zeros(len(dataset))
357
- heights = np.zeros(len(dataset))
358
-
359
- # Создаем уникальные ключи для размеров
360
- self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
361
-
362
- # Группируем индексы по размерам используя numpy
363
- self.size_groups = {}
364
- for w, h in self.size_keys:
365
- mask = (widths == w) & (heights == h)
366
- self.size_groups[(w, h)] = np.where(mask)[0]
367
-
368
- # Предварительно вычисляем количество полных батчей для каждой группы
369
- self.group_num_batches = {}
370
- total_batches = 0
371
- for size, indices in self.size_groups.items():
372
- num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
373
- self.group_num_batches[size] = num_full_batches
374
- total_batches += num_full_batches
375
-
376
- # Округляем до числа, делящегося на num_replicas
377
- self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
378
-
379
- def __iter__(self):
380
- # print(f"Rank {self.rank}: Starting iteration")
381
- # Очищаем CUDA кэш перед формированием новых батчей
382
- if torch.cuda.is_available():
383
- torch.cuda.empty_cache()
384
- all_batches = []
385
- rng = np.random.RandomState(self.epoch)
386
-
387
- for size, indices in self.size_groups.items():
388
- # print(f"Rank {self.rank}: Processing size {size}, {len(indices)} samples")
389
- indices = indices.copy()
390
- if self.shuffle:
391
- rng.shuffle(indices)
392
-
393
- num_full_batches = self.group_num_batches[size]
394
- if num_full_batches == 0:
395
- continue
396
-
397
- # Берем только индексы для полных батчей
398
- valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
399
-
400
- # Reshape для быстрого разделения на батчи
401
- batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
402
-
403
- # Выбираем часть для текущего GPU
404
- start_idx = self.rank * self.batch_size
405
- end_idx = start_idx + self.batch_size
406
- gpu_batches = batches[:, start_idx:end_idx]
407
-
408
- all_batches.extend(gpu_batches)
409
-
410
- if self.shuffle:
411
- rng.shuffle(all_batches)
412
-
413
- # Синхронизируем все процессы после формирования батчей
414
- accelerator.wait_for_everyone()
415
- # print(f"Rank {self.rank}: Created {len(all_batches)} batches")
416
- return iter(all_batches)
417
-
418
- def __len__(self):
419
- return self.num_batches
420
-
421
- def set_epoch(self, epoch):
422
- self.epoch = epoch
423
-
424
- # Функция для выборки фиксированных семплов по размерам
425
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
426
- """Выбирает фиксированные семплы для каждого уникального разрешения"""
427
- # Группируем по размерам
428
- size_groups = defaultdict(list)
429
- try:
430
- widths = dataset["width"]
431
- heights = dataset["height"]
432
- except KeyError:
433
- widths = [0] * len(dataset)
434
- heights = [0] * len(dataset)
435
- for i, (w, h) in enumerate(zip(widths, heights)):
436
- size = (w, h)
437
- size_groups[size].append(i)
438
-
439
- # Выбираем фиксированные примеры из каждой группы
440
- fixed_samples = {}
441
- for size, indices in size_groups.items():
442
- # Определяем сколько семплов брать из этой группы
443
- n_samples = min(samples_per_group, len(indices))
444
- if len(size_groups)==1:
445
- n_samples = samples_to_generate
446
- if n_samples == 0:
447
- continue
448
-
449
- # Выбираем случайные индексы
450
- sample_indices = random.sample(indices, n_samples)
451
- samples_data = [dataset[idx] for idx in sample_indices]
452
-
453
- # Собираем данные
454
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
455
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
456
- texts = [item["text"] for item in samples_data]
457
-
458
- # Сохраняем для этого размера
459
- fixed_samples[size] = (latents, embeddings, texts)
460
-
461
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
462
- return fixed_samples
463
-
464
- if limit > 0:
465
- dataset = load_from_disk(ds_path).select(range(limit))
466
- else:
467
- dataset = load_from_disk(ds_path)
468
-
469
- def collate_fn_simple(batch):
470
- # Преобразуем список в тензоры и перемещаем на девайс
471
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
472
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
473
- return latents, embeddings
474
-
475
- def collate_fn(batch):
476
- if not batch:
477
- return [], []
478
-
479
- # Берем эталонную форму
480
- ref_vae_shape = np.array(batch[0]["vae"]).shape
481
- ref_embed_shape = np.array(batch[0]["embeddings"]).shape
482
-
483
- # Фильтруем
484
- valid_latents = []
485
- valid_embeddings = []
486
- for item in batch:
487
- if (np.array(item["vae"]).shape == ref_vae_shape and
488
- np.array(item["embeddings"]).shape == ref_embed_shape):
489
- valid_latents.append(item["vae"])
490
- valid_embeddings.append(item["embeddings"])
491
-
492
- # Создаем тензоры
493
- latents = torch.tensor(np.array(valid_latents)).to(device,dtype=dtype)
494
- embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype)
495
-
496
- return latents, embeddings
497
-
498
- # Создаем ResolutionBatchSampler на основе индексов от DistributedSampler
499
- batch_sampler = DistributedResolutionBatchSampler(
500
- dataset=dataset,
501
- batch_size=batch_size,
502
- num_replicas=accelerator.num_processes,
503
- rank=accelerator.process_index,
504
- shuffle=shuffle
505
- )
506
-
507
- # Создаем DataLoader
508
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
509
-
510
- print("Total samples",len(dataloader))
511
- dataloader = accelerator.prepare(dataloader)
512
-
513
- # Инициализация переменных для возобновления обучения
514
- start_epoch = 0
515
- global_step = 0
516
-
517
- # Расчёт общего количества шагов
518
- total_training_steps = (len(dataloader) * num_epochs)
519
- # Get the world size
520
- world_size = accelerator.state.num_processes
521
- #print(f"World Size: {world_size}")
522
-
523
- # Опция загрузки модели из последнего чекпоинта (если существует)
524
- latest_checkpoint = os.path.join(checkpoints_folder, project)
525
- if os.path.isdir(latest_checkpoint):
526
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
527
- #if dtype == torch.float32:
528
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
529
- #else:
530
- #unet = UNet2DConditionModel.from_pretrained(latest_checkpoint, variant="fp16").to(device=device,dtype=dtype)
531
- if unet_gradient:
532
- unet.enable_gradient_checkpointing()
533
- unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
534
- try:
535
- unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
536
- except Exception as e:
537
- print(f"Ошибка при включении SDPA: {e}")
538
- print("Попытка использовать enable_xformers_memory_efficient_attention.")
539
- unet.set_use_memory_efficient_attention_xformers(True)
540
-
541
- if hasattr(torch.backends.cuda, "flash_sdp_enabled"):
542
- print(f"torch.backends.cuda.flash_sdp_enabled(): {torch.backends.cuda.flash_sdp_enabled()}")
543
- if hasattr(torch.backends.cuda, "mem_efficient_sdp_enabled"):
544
- print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
545
- if hasattr(torch.nn.functional, "get_flash_attention_available"):
546
- print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
547
-
548
- # Регистрируем хук на модел
549
- if dispersive_loss:
550
- dispersive_hook = AccelerateDispersiveLoss(
551
- accelerator=accelerator,
552
- temperature=dispersive_temperature,
553
- weight=dispersive_weight
554
- )
555
-
556
- if torch_compile:
557
- print("compiling")
558
- torch.set_float32_matmul_precision('high')
559
- unet = torch.compile(unet)#, mode="reduce-overhead", fullgraph=True)
560
- print("compiling - ok")
561
-
562
- if lora_name:
563
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
564
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
565
- from peft.tuners.lora import LoraModel
566
- import os
567
- # 1. Замораживаем все параметры UNet
568
- unet.requires_grad_(False)
569
- print("Параметры базового UNet заморожены.")
570
-
571
- # 2. Создаем конфигурацию LoRA
572
- lora_config = LoraConfig(
573
- r=lora_rank,
574
- lora_alpha=lora_alpha,
575
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
576
- )
577
- unet.add_adapter(lora_config)
578
-
579
- # 3. Оборачиваем UNet в PEFT-модель
580
- from peft import get_peft_model
581
-
582
- peft_unet = get_peft_model(unet, lora_config)
583
-
584
- # 4. Получаем параметры для оптимизации
585
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
586
-
587
-
588
- # 5. Выводим информацию о количестве параметров
589
- if accelerator.is_main_process:
590
- lora_params_count = sum(p.numel() for p in params_to_optimize)
591
- total_params_count = sum(p.numel() for p in unet.parameters())
592
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
593
- print(f"Общее количество параметров UNet: {total_params_count:,}")
594
-
595
- # 6. Путь для сохранения
596
- lora_save_path = os.path.join("lora", lora_name)
597
- os.makedirs(lora_save_path, exist_ok=True)
598
-
599
- # 7. Функция для сохранения
600
- def save_lora_checkpoint(model):
601
- if accelerator.is_main_process:
602
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
603
- from peft.utils.save_and_load import get_peft_model_state_dict
604
- # Получаем state_dict только LoRA
605
- lora_state_dict = get_peft_model_state_dict(model)
606
-
607
- # Сохраняем веса
608
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
609
-
610
- # Сохраняем конфиг
611
- model.peft_config["default"].save_pretrained(lora_save_path)
612
- # SDXL must be compatible
613
- from diffusers import StableDiffusionXLPipeline
614
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
615
-
616
- # --------------------------- Оптимизатор ---------------------------
617
- # Определяем параметры для оптимизации
618
- #unet = torch.compile(unet)
619
- if lora_name:
620
- # Если используется LoRA, оптимизируем только параметры LoRA
621
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
622
- else:
623
- # Иначе оптимизируем все параметры
624
- if fbp:
625
- trainable_params = list(unet.parameters())
626
-
627
- def create_optimizer(name, params):
628
- if name == "adam8bit":
629
- return bnb.optim.AdamW8bit(
630
- params, lr=base_learning_rate, betas=(0.9, 0.97), eps=1e-5, weight_decay=0.001,
631
- percentile_clipping=percentile_clipping
632
- )
633
- elif name == "adam":
634
- return torch.optim.AdamW(
635
- params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
636
- )
637
- elif name == "lion8bit":
638
- return bnb.optim.Lion8bit(
639
- params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
640
- percentile_clipping=percentile_clipping
641
- )
642
- elif name == "adafactor":
643
- from transformers import Adafactor
644
- return Adafactor(
645
- params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
646
- warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
647
- beta1=0.9, weight_decay=0.01
648
- )
649
- else:
650
- raise ValueError(f"Unknown optimizer: {name}")
651
-
652
- if fbp:
653
- # Создаем отдельный оптимизатор для каждого параметра
654
- optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
655
-
656
- def optimizer_hook(param):
657
- optimizer_dict[param].step()
658
- optimizer_dict[param].zero_grad(set_to_none=True)
659
-
660
- for param in trainable_params:
661
- param.register_post_accumulate_grad_hook(optimizer_hook)
662
-
663
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
664
- else:
665
- optimizer = create_optimizer(optimizer_type, unet.parameters())
666
-
667
- def lr_schedule(step):
668
- x = step / (total_training_steps * world_size)
669
- warmup = 0.05
670
-
671
- if not use_decay:
672
- return base_learning_rate
673
- if x < warmup:
674
- return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
675
-
676
- decay_ratio = (x - warmup) / (1 - warmup)
677
- return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
678
- (1 + math.cos(math.pi * decay_ratio))
679
-
680
- lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
681
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
682
-
683
- # Регистрация хуков ПОСЛЕ prepare
684
- if dispersive_loss:
685
- dispersive_hook.register_hooks(unet, "down_blocks.2")
686
-
687
- # --------------------------- Фиксированные семплы для генерации ---------------------------
688
- # Примеры фиксированных семплов по размерам
689
- fixed_samples = get_fixed_samples_by_resolution(dataset)
690
-
691
- @torch.compiler.disable()
692
- @torch.no_grad()
693
- def generate_and_save_samples(fixed_samples_cpu, step):
694
- """
695
- Генерирует семплы для каждого из разрешений и сохраняет их.
696
-
697
- Args:
698
- fixed_samples_cpu: Словарь, где ключи - размеры (width, height),
699
- а значения - кортежи (latents, embeddings, text) на CPU.
700
- step: Текущий шаг обучения
701
- """
702
- original_model = None # Инициализируем, чтобы finally не ругался
703
- try:
704
-
705
- original_model = accelerator.unwrap_model(unet).eval()
706
-
707
- vae.to(device=device, dtype=dtype)
708
- vae.eval()
709
-
710
- scheduler.set_timesteps(n_diffusion_steps)
711
-
712
- all_generated_images = []
713
- all_captions = []
714
-
715
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
716
- width, height = size
717
-
718
- sample_latents = sample_latents.to(dtype=dtype)
719
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype)
720
-
721
- # Инициализируем латенты случайным шумом
722
- # sample_latents уже в dtype, так что noise будет создан в dtype
723
- noise = torch.randn(
724
- sample_latents.shape, # Используем форму от sample_latents, которые теперь на GPU и fp16
725
- generator=gen,
726
- device=device,
727
- dtype=sample_latents.dtype
728
- )
729
- current_latents = noise.clone()
730
-
731
- # Подготовка текстовых эмбеддингов для guidance
732
- if guidance_scale > 0:
733
- # empty_embeddings должны быть того же типа и на том же устройстве
734
- empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
735
- text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
736
- else:
737
- text_embeddings_batch = sample_text_embeddings
738
-
739
- for t in scheduler.timesteps:
740
- t_batch = t.repeat(current_latents.shape[0]).to(device) # Убедимся, что t на устройстве
741
-
742
- if guidance_scale > 0:
743
- latent_model_input = torch.cat([current_latents] * 2)
744
- else:
745
- latent_model_input = current_latents
746
-
747
- latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
748
-
749
- # Предсказание шума (UNet)
750
- noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
751
-
752
- if guidance_scale > 0:
753
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
754
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
755
-
756
- current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
757
-
758
- #print(f"current_latents Min: {current_latents.min()} Max: {current_latents.max()}")
759
- # Декодирование через VAE
760
- latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
761
- decoded = vae.decode(latent_for_vae).sample
762
-
763
- # Преобразуем тензоры в PIL-изображения
764
- # Для математики с изображением (нормализация) лучше перейти в fp32
765
- decoded_fp32 = decoded.to(torch.float32)
766
- for img_idx, img_tensor in enumerate(decoded_fp32):
767
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
768
- # If NaNs or infs are present, print them
769
- if np.isnan(img).any():
770
- print("NaNs found, saving stoped! Step:", step)
771
- save_model = False
772
- pil_img = Image.fromarray((img * 255).astype("uint8"))
773
-
774
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
775
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
776
- max_w_overall = max(255, max_w_overall)
777
- max_h_overall = max(255, max_h_overall)
778
-
779
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
780
- all_generated_images.append(padded_img)
781
-
782
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
783
- all_captions.append(caption_text)
784
-
785
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
786
- pil_img.save(sample_path, "JPEG", quality=96)
787
-
788
- if use_wandb and accelerator.is_main_process:
789
- wandb_images = [
790
- wandb.Image(img, caption=f"{all_captions[i]}")
791
- for i, img in enumerate(all_generated_images)
792
- ]
793
- wandb.log({"generated_images": wandb_images, "global_step": step})
794
-
795
- finally:
796
- vae.to("cpu") # Перемещаем VAE обратно на CPU
797
- # Очистка переменных, которые являются тензорами и были созданы в функции
798
- for var in list(locals().keys()):
799
- if isinstance(locals()[var], torch.Tensor):
800
- del locals()[var]
801
-
802
- torch.cuda.empty_cache()
803
- gc.collect()
804
-
805
- # --------------------------- Генерация сэмплов перед обучением ---------------------------
806
- if accelerator.is_main_process:
807
- if save_model:
808
- print("Генерация сэмплов до старта обучения...")
809
- generate_and_save_samples(fixed_samples,0)
810
- accelerator.wait_for_everyone()
811
-
812
- # Модифицируем функцию сохранения модели для поддержки LoRA
813
- def save_checkpoint(unet,variant=""):
814
- if accelerator.is_main_process:
815
- if lora_name:
816
- # Сохраняем только LoRA адаптеры
817
- save_lora_checkpoint(unet)
818
- else:
819
- # Сохраняем полную модель
820
- if variant!="":
821
- accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
822
- else:
823
- accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
824
- unet = unet.to(dtype=dtype)
825
-
826
- # --------------------------- Тренировочный цикл ---------------------------
827
- # Для логирования среднего лосса каждые % эпохи
828
- if accelerator.is_main_process:
829
- print(f"Total steps per GPU: {total_training_steps}")
830
-
831
- epoch_loss_points = []
832
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
833
-
834
- # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
835
- steps_per_epoch = len(dataloader)
836
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
837
- min_loss = 1.
838
-
839
- # Начинаем с указанной эпохи (полезно при возобновлении)
840
- for epoch in range(start_epoch, start_epoch + num_epochs):
841
- batch_losses = []
842
- batch_tlosses = []
843
- batch_grads = []
844
- #unet = unet.to(dtype = dtype)
845
- batch_sampler.set_epoch(epoch)
846
- accelerator.wait_for_everyone()
847
- unet.train()
848
- print("epoch:",epoch)
849
- for step, (latents, embeddings) in enumerate(dataloader):
850
- with accelerator.accumulate(unet):
851
- if save_model == False and step == 5 :
852
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
853
- print(f"Шаг {step}: {used_gb:.2f} GB")
854
-
855
- # Forward pass
856
- noise = torch.randn_like(latents, dtype=latents.dtype)
857
-
858
- timesteps = torch.randint(steps_offset, scheduler.config.num_train_timesteps,
859
- (latents.shape[0],), device=device).long()
860
-
861
- # Добавляем шум к латентам
862
- noisy_latents = scheduler.add_noise(latents, noise, timesteps)
863
-
864
- # Очищаем активации перед forward pass
865
- if dispersive_loss:
866
- dispersive_hook.clear_activations()
867
-
868
- # Используем целевое значение
869
- model_pred = unet(noisy_latents, timesteps, embeddings).sample
870
- target_pred = scheduler.get_velocity(latents, noise, timesteps)
871
-
872
- # Считаем лосс
873
- loss = torch.nn.functional.mse_loss(model_pred.float(), target_pred.float())
874
-
875
- # Dispersive Loss
876
- #Идентичные векторы: Loss = -0.0000
877
- #Ортогональные векторы: Loss = -3.9995
878
- if dispersive_loss:
879
- with torch.amp.autocast('cuda', enabled=False):
880
- dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
881
- if torch.isnan(dispersive_loss) or torch.isinf(dispersive_loss):
882
- print(f"Rank {accelerator.process_index}: Found nan/inf in dispersive_loss: {total_loss}")
883
-
884
- # Итоговый loss
885
- # dispersive_loss должен падать и тотал падать - поэтому плюс
886
- total_loss = loss + dispersive_loss
887
-
888
- # Проверяем на nan/inf перед backward
889
- if torch.isnan(loss) or torch.isinf(loss):
890
- print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
891
- save_model = False
892
- break
893
-
894
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
895
- accelerator.wait_for_everyone()
896
-
897
- # Делаем backward через Accelerator
898
- accelerator.backward(total_loss)
899
-
900
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
901
- accelerator.wait_for_everyone()
902
-
903
- grad = 0.0
904
- if not fbp:
905
- if accelerator.sync_gradients:
906
- with torch.amp.autocast('cuda', enabled=False):
907
- grad = accelerator.clip_grad_norm_(unet.parameters(), 0.25)
908
- optimizer.step()
909
- lr_scheduler.step()
910
- optimizer.zero_grad(set_to_none=True)
911
-
912
- # Увеличиваем счетчик глобальных шагов
913
- global_step += 1
914
-
915
- # Обновляем прогресс-бар
916
- progress_bar.update(1)
917
-
918
- # Логируем метрики
919
- if accelerator.is_main_process:
920
- if fbp:
921
- current_lr = base_learning_rate
922
- else:
923
- current_lr = lr_scheduler.get_last_lr()[0]
924
- batch_losses.append(loss.detach().item())
925
- batch_tlosses.append(total_loss.detach().item())
926
- batch_grads.append(grad)
927
-
928
- # Логируем в Wandb
929
- if use_wandb:
930
- wandb.log({
931
- "mse_loss": loss.detach().item(),
932
- "learning_rate": current_lr,
933
- "epoch": epoch,
934
- "grad": grad,
935
- "global_step": global_step,
936
- "dispersive_loss": dispersive_loss,
937
- "total_loss": total_loss
938
- })
939
-
940
- # Генерируем сэмплы с заданным интервалом
941
- if global_step % sample_interval == 0:
942
- generate_and_save_samples(fixed_samples,global_step)
943
-
944
- # Выводим текущий лосс
945
- avg_loss = np.mean(batch_losses[-sample_interval:])
946
- avg_tloss = np.mean(batch_tlosses[-sample_interval:])
947
- avg_grad = torch.mean(torch.stack(batch_grads[-sample_interval:])).cpu().item()
948
- print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}")
949
-
950
- if save_model:
951
- print("saving:",avg_loss < min_loss*save_barrier)
952
- if avg_loss < min_loss*save_barrier:
953
- min_loss = avg_loss
954
- save_checkpoint(unet)
955
- if use_wandb:
956
- wandb.log({"interm_loss": avg_loss})
957
- wandb.log({"interm_totalloss": avg_tloss})
958
- wandb.log({"interm_grad": avg_grad})
959
-
960
-
961
- # По окончании эпохи
962
- #accelerator.wait_for_everyone()
963
- if accelerator.is_main_process:
964
- avg_epoch_loss = np.mean(batch_losses)
965
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
966
- if use_wandb:
967
- wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
968
-
969
- # Завершение обучения - сохраняем финальную модель
970
- if dispersive_loss:
971
- dispersive_hook.remove_hooks()
972
- if accelerator.is_main_process:
973
- print("Обучение завершено! Сохраняем финальную модель...")
974
- # Сохраняем основную модель
975
- if save_model:
976
- save_checkpoint(unet,"fp16")
977
- print("Готово!")
978
-
979
- # randomize ode timesteps
980
- # input_timestep = torch.round(
981
- # F.sigmoid(torch.randn((n,), device=latents.device)), decimals=3
982
- # )
983
-
984
- #def create_distribution(num_points, device=None):
985
- # # Диапазон вероятностей на оси x
986
- # x = torch.linspace(0, 1, num_points, device=device)
987
-
988
- # Пользовательская функция плотности вероятности
989
- # probabilities = -7.7 * ((x - 0.5) ** 2) + 2
990
-
991
- # Нормализация, чтобы сумма равнялась 1
992
- # probabilities /= probabilities.sum()
993
-
994
- # return x, probabilities
995
-
996
- #def sample_from_distribution(x, probabilities, n, device=None):
997
- # Выбор индексов на основе распределения вероятностей
998
- # indices = torch.multinomial(probabilities, n, replacement=True)
999
- # return x[indices]
1000
-
1001
- # Пример использования
1002
- #num_points = 1000 # Количество точек в диапазоне
1003
- #n = latents.shape[0] # Количество временных шагов для выборки
1004
- #x, probabilities = create_distribution(num_points, device=latents.device)
1005
- #timesteps = sample_from_distribution(x, probabilities, n, device=latents.device)
1006
-
1007
- # Преобразование в формат, подходящий для вашего кода
1008
- #timesteps = (timesteps * (scheduler.config.num_train_timesteps - 1)).long()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train-Copy2.py DELETED
@@ -1,874 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from collections import defaultdict
10
- from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
12
- from accelerate import Accelerator
13
- from datasets import load_from_disk
14
- from tqdm import tqdm
15
- from PIL import Image,ImageOps
16
- import wandb
17
- import random
18
- import gc
19
- from accelerate.state import DistributedType
20
- from torch.distributed import broadcast_object_list
21
- from torch.utils.checkpoint import checkpoint
22
- from diffusers.models.attention_processor import AttnProcessor2_0
23
- from datetime import datetime
24
- import bitsandbytes as bnb
25
- import torch.nn.functional as F
26
-
27
- # --------------------------- Параметры ---------------------------
28
- ds_path = "datasets/384"
29
- project = "micro"
30
- batch_size = 64
31
- base_learning_rate = 1e-4
32
- min_learning_rate = 5e-5
33
- num_epochs = 50
34
- # samples/save per epoch
35
- sample_interval_share = 10
36
- use_wandb = True
37
- save_model = True
38
- use_decay = True
39
- fbp = False # fused backward pass
40
- optimizer_type = "adam8bit"
41
- torch_compile = False
42
- unet_gradient = True
43
- clip_sample = False #Scheduler
44
- fixed_seed = False
45
- shuffle = True
46
- dispersive_loss_enabled = True
47
- torch.backends.cuda.matmul.allow_tf32 = True
48
- torch.backends.cudnn.allow_tf32 = True
49
- torch.backends.cuda.enable_mem_efficient_sdp(False)
50
- dtype = torch.float32
51
- save_barrier = 1.03
52
- warmup_percent = 0.01
53
- dispersive_temperature=0.5
54
- dispersive_weight= 0.05
55
- percentile_clipping = 95 # 8bit optim
56
- betta2 = 0.97
57
- eps = 1e-6
58
- clip_grad_norm = 1.0
59
- steps_offset = 0 # Scheduler
60
- limit = 0
61
- checkpoints_folder = ""
62
- mixed_precision = "no" #"fp16"
63
- gradient_accumulation_steps = 1
64
- accelerator = Accelerator(
65
- mixed_precision=mixed_precision,
66
- gradient_accumulation_steps=gradient_accumulation_steps
67
- )
68
- device = accelerator.device
69
-
70
- # Параметры для диффузии
71
- n_diffusion_steps = 50
72
- samples_to_generate = 12
73
- guidance_scale = 5
74
-
75
- # Папки для сохранения результатов
76
- generated_folder = "samples"
77
- os.makedirs(generated_folder, exist_ok=True)
78
-
79
- # Настройка seed для воспроизводимости
80
- current_date = datetime.now()
81
- seed = int(current_date.strftime("%Y%m%d"))
82
- if fixed_seed:
83
- torch.manual_seed(seed)
84
- np.random.seed(seed)
85
- random.seed(seed)
86
- if torch.cuda.is_available():
87
- torch.cuda.manual_seed_all(seed)
88
-
89
- # --------------------------- Параметры LoRA ---------------------------
90
- # pip install peft
91
- lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
92
- lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель)
93
- lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб
94
-
95
- print("init")
96
-
97
- class AccelerateDispersiveLoss:
98
- def __init__(self, accelerator, temperature=0.5, weight=0.5):
99
- self.accelerator = accelerator
100
- self.temperature = temperature
101
- self.weight = weight
102
- self.activations = []
103
- self.hooks = []
104
-
105
- def register_hooks(self, model, target_layer="down_blocks.0"):
106
- unwrapped_model = self.accelerator.unwrap_model(model)
107
- print("=== Поиск слоев в unwrapped модели ===")
108
- for name, module in unwrapped_model.named_modules():
109
- if target_layer in name:
110
- hook = module.register_forward_hook(self.hook_fn)
111
- self.hooks.append(hook)
112
- print(f"✅ Хук зарегистрирован на: {name}")
113
- break
114
-
115
- def hook_fn(self, module, input, output):
116
-
117
- if isinstance(output, tuple):
118
- activation = output[0]
119
- else:
120
- activation = output
121
-
122
- if len(activation.shape) > 2:
123
- activation = activation.view(activation.shape[0], -1)
124
-
125
- self.activations.append(activation.detach())
126
-
127
- def compute_dispersive_loss(self):
128
- if not self.activations:
129
- return torch.tensor(0.0, requires_grad=True)
130
-
131
- local_activations = self.activations[-1].float()
132
-
133
- batch_size = local_activations.shape[0]
134
- if batch_size < 2:
135
- return torch.tensor(0.0, requires_grad=True)
136
-
137
- # Нормализация и вычисление loss
138
- sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
139
- distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
140
- exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
141
- dispersive_loss = torch.log(torch.mean(exp_neg_dist))
142
-
143
- # ВАЖНО: он отриц и должен падать
144
- return dispersive_loss
145
-
146
- def clear_activations(self):
147
- self.activations.clear()
148
-
149
- def remove_hooks(self):
150
- for hook in self.hooks:
151
- hook.remove()
152
- self.hooks.clear()
153
-
154
-
155
-
156
- # --------------------------- Инициализация WandB ---------------------------
157
- if use_wandb and accelerator.is_main_process:
158
- wandb.init(project=project+lora_name, config={
159
- "batch_size": batch_size,
160
- "base_learning_rate": base_learning_rate,
161
- "num_epochs": num_epochs,
162
- "fbp": fbp,
163
- "optimizer_type": optimizer_type,
164
- })
165
-
166
- # Включение Flash Attention 2/SDPA
167
- torch.backends.cuda.enable_flash_sdp(True)
168
- # --------------------------- Инициализация Accelerator --------------------
169
- gen = torch.Generator(device=device)
170
- gen.manual_seed(seed)
171
-
172
- # --------------------------- Загрузка моделей ---------------------------
173
- # VAE загружается на CPU для экономии GPU-памяти
174
- vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to("cpu").eval()
175
-
176
- # DDPMScheduler с V_Prediction и Zero-SNR
177
- scheduler = DDPMScheduler(
178
- num_train_timesteps=1000, # Полный график шагов для обучения
179
- prediction_type="v_prediction", # V-Prediction
180
- rescale_betas_zero_snr=True, # Включение Zero-SNR
181
- clip_sample = clip_sample,
182
- steps_offset = steps_offset
183
- )
184
-
185
-
186
- class DistributedResolutionBatchSampler(Sampler):
187
- def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
188
- self.dataset = dataset
189
- self.batch_size = max(1, batch_size // num_replicas)
190
- self.num_replicas = num_replicas
191
- self.rank = rank
192
- self.shuffle = shuffle
193
- self.drop_last = drop_last
194
- self.epoch = 0
195
-
196
- # Используем numpy для ускорения
197
- try:
198
- widths = np.array(dataset["width"])
199
- heights = np.array(dataset["height"])
200
- except KeyError:
201
- widths = np.zeros(len(dataset))
202
- heights = np.zeros(len(dataset))
203
-
204
- # Создаем уникальные ключи для размеров
205
- self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
206
-
207
- # Группируем индексы по размерам используя numpy
208
- self.size_groups = {}
209
- for w, h in self.size_keys:
210
- mask = (widths == w) & (heights == h)
211
- self.size_groups[(w, h)] = np.where(mask)[0]
212
-
213
- # Предварительно вычисляем количество полных батчей для каждой группы
214
- self.group_num_batches = {}
215
- total_batches = 0
216
- for size, indices in self.size_groups.items():
217
- num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
218
- self.group_num_batches[size] = num_full_batches
219
- total_batches += num_full_batches
220
-
221
- # Округляем до числа, делящегося на num_replicas
222
- self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
223
-
224
- def __iter__(self):
225
- # print(f"Rank {self.rank}: Starting iteration")
226
- # Очищаем CUDA кэш перед формированием новых батчей
227
- if torch.cuda.is_available():
228
- torch.cuda.empty_cache()
229
- all_batches = []
230
- rng = np.random.RandomState(self.epoch)
231
-
232
- for size, indices in self.size_groups.items():
233
- # print(f"Rank {self.rank}: Processing size {size}, {len(indices)} samples")
234
- indices = indices.copy()
235
- if self.shuffle:
236
- rng.shuffle(indices)
237
-
238
- num_full_batches = self.group_num_batches[size]
239
- if num_full_batches == 0:
240
- continue
241
-
242
- # Берем только индексы для полных батчей
243
- valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
244
-
245
- # Reshape для быстрого разделения на батчи
246
- batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
247
-
248
- # Выбираем часть для текущего GPU
249
- start_idx = self.rank * self.batch_size
250
- end_idx = start_idx + self.batch_size
251
- gpu_batches = batches[:, start_idx:end_idx]
252
-
253
- all_batches.extend(gpu_batches)
254
-
255
- if self.shuffle:
256
- rng.shuffle(all_batches)
257
-
258
- # Синхронизируем все процессы после формирования батчей
259
- accelerator.wait_for_everyone()
260
- # print(f"Rank {self.rank}: Created {len(all_batches)} batches")
261
- return iter(all_batches)
262
-
263
- def __len__(self):
264
- return self.num_batches
265
-
266
- def set_epoch(self, epoch):
267
- self.epoch = epoch
268
-
269
- # Функция для выборки фиксированных семплов по размерам
270
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
271
- """Выбирает фиксированные семплы для каждого уникального разрешения"""
272
- # Группируем по размерам
273
- size_groups = defaultdict(list)
274
- try:
275
- widths = dataset["width"]
276
- heights = dataset["height"]
277
- except KeyError:
278
- widths = [0] * len(dataset)
279
- heights = [0] * len(dataset)
280
- for i, (w, h) in enumerate(zip(widths, heights)):
281
- size = (w, h)
282
- size_groups[size].append(i)
283
-
284
- # Выбираем фиксированные примеры из каждой группы
285
- fixed_samples = {}
286
- for size, indices in size_groups.items():
287
- # Определяем сколько семплов брать из этой группы
288
- n_samples = min(samples_per_group, len(indices))
289
- if len(size_groups)==1:
290
- n_samples = samples_to_generate
291
- if n_samples == 0:
292
- continue
293
-
294
- # Выбираем случайные индексы
295
- sample_indices = random.sample(indices, n_samples)
296
- samples_data = [dataset[idx] for idx in sample_indices]
297
-
298
- # Собираем данные
299
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
300
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
301
- texts = [item["text"] for item in samples_data]
302
-
303
- # Сохраняем для этого размера
304
- fixed_samples[size] = (latents, embeddings, texts)
305
-
306
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
307
- return fixed_samples
308
-
309
- if limit > 0:
310
- dataset = load_from_disk(ds_path).select(range(limit))
311
- else:
312
- dataset = load_from_disk(ds_path)
313
-
314
- def collate_fn_simple(batch):
315
- # Преобразуем список в тензоры и перемещаем на девайс
316
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
317
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
318
- return latents, embeddings
319
-
320
- def collate_fn(batch):
321
- if not batch:
322
- return [], []
323
-
324
- # Берем эталонную форму
325
- ref_vae_shape = np.array(batch[0]["vae"]).shape
326
- ref_embed_shape = np.array(batch[0]["embeddings"]).shape
327
-
328
- # Фильтруем
329
- valid_latents = []
330
- valid_embeddings = []
331
- for item in batch:
332
- if (np.array(item["vae"]).shape == ref_vae_shape and
333
- np.array(item["embeddings"]).shape == ref_embed_shape):
334
- valid_latents.append(item["vae"])
335
- valid_embeddings.append(item["embeddings"])
336
-
337
- # Создаем тензоры
338
- latents = torch.tensor(np.array(valid_latents)).to(device,dtype=dtype)
339
- embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype)
340
-
341
- return latents, embeddings
342
-
343
- # Создаем ResolutionBatchSampler на основе индексов от DistributedSampler
344
- batch_sampler = DistributedResolutionBatchSampler(
345
- dataset=dataset,
346
- batch_size=batch_size,
347
- num_replicas=accelerator.num_processes,
348
- rank=accelerator.process_index,
349
- shuffle=shuffle
350
- )
351
-
352
- # Создаем DataLoader
353
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
354
-
355
- print("Total samples",len(dataloader))
356
- dataloader = accelerator.prepare(dataloader)
357
-
358
- # Инициализация переменных для возобновления обучения
359
- start_epoch = 0
360
- global_step = 0
361
-
362
- # Расчёт общего количества шагов
363
- total_training_steps = (len(dataloader) * num_epochs)
364
- # Get the world size
365
- world_size = accelerator.state.num_processes
366
- #print(f"World Size: {world_size}")
367
-
368
- # Опция загрузки модели из последнего чекпоинта (если существует)
369
- latest_checkpoint = os.path.join(checkpoints_folder, project)
370
- if os.path.isdir(latest_checkpoint):
371
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
372
- #if dtype == torch.float32:
373
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
374
- #else:
375
- #unet = UNet2DConditionModel.from_pretrained(latest_checkpoint, variant="fp16").to(device=device,dtype=dtype)
376
- if unet_gradient:
377
- unet.enable_gradient_checkpointing()
378
- unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
379
- try:
380
- unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
381
- except Exception as e:
382
- print(f"Оши��ка при включении SDPA: {e}")
383
- print("Попытка использовать enable_xformers_memory_efficient_attention.")
384
- unet.set_use_memory_efficient_attention_xformers(True)
385
-
386
- if hasattr(torch.backends.cuda, "flash_sdp_enabled"):
387
- print(f"torch.backends.cuda.flash_sdp_enabled(): {torch.backends.cuda.flash_sdp_enabled()}")
388
- if hasattr(torch.backends.cuda, "mem_efficient_sdp_enabled"):
389
- print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
390
- if hasattr(torch.nn.functional, "get_flash_attention_available"):
391
- print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
392
-
393
- # Регистрируем хук на модел
394
- if dispersive_loss_enabled:
395
- dispersive_hook = AccelerateDispersiveLoss(
396
- accelerator=accelerator,
397
- temperature=dispersive_temperature,
398
- weight=dispersive_weight
399
- )
400
-
401
- if torch_compile:
402
- print("compiling")
403
- torch.set_float32_matmul_precision('high')
404
- unet = torch.compile(unet, mode="reduce-overhead", fullgraph=False)
405
- print("compiling - ok")
406
-
407
- if lora_name:
408
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
409
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
410
- from peft.tuners.lora import LoraModel
411
- import os
412
- # 1. Замораживаем все параметры UNet
413
- unet.requires_grad_(False)
414
- print("Параметры базового UNet заморожены.")
415
-
416
- # 2. Создаем конфигурацию LoRA
417
- lora_config = LoraConfig(
418
- r=lora_rank,
419
- lora_alpha=lora_alpha,
420
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
421
- )
422
- unet.add_adapter(lora_config)
423
-
424
- # 3. Оборачиваем UNet в PEFT-модель
425
- from peft import get_peft_model
426
-
427
- peft_unet = get_peft_model(unet, lora_config)
428
-
429
- # 4. Получаем параметры для оптимизации
430
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
431
-
432
-
433
- # 5. Выводим информацию о количестве параметров
434
- if accelerator.is_main_process:
435
- lora_params_count = sum(p.numel() for p in params_to_optimize)
436
- total_params_count = sum(p.numel() for p in unet.parameters())
437
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
438
- print(f"Общее количество параметров UNet: {total_params_count:,}")
439
-
440
- # 6. Путь для сохранения
441
- lora_save_path = os.path.join("lora", lora_name)
442
- os.makedirs(lora_save_path, exist_ok=True)
443
-
444
- # 7. Функция для сохранения
445
- def save_lora_checkpoint(model):
446
- if accelerator.is_main_process:
447
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
448
- from peft.utils.save_and_load import get_peft_model_state_dict
449
- # Получаем state_dict только LoRA
450
- lora_state_dict = get_peft_model_state_dict(model)
451
-
452
- # Сохраняем веса
453
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
454
-
455
- # Сохраняем конфиг
456
- model.peft_config["default"].save_pretrained(lora_save_path)
457
- # SDXL must be compatible
458
- from diffusers import StableDiffusionXLPipeline
459
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
460
-
461
- # --------------------------- Оптимизатор ---------------------------
462
- # Определяем параметры для оптимизации
463
- #unet = torch.compile(unet)
464
- if lora_name:
465
- # Если используется LoRA, оптимизируем только параметры LoRA
466
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
467
- else:
468
- # Иначе оптимизируем все параметры
469
- if fbp:
470
- trainable_params = list(unet.parameters())
471
-
472
- def create_optimizer(name, params):
473
- if name == "adam8bit":
474
- return bnb.optim.AdamW8bit(
475
- params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001,
476
- percentile_clipping=percentile_clipping
477
- )
478
- elif name == "adam":
479
- return torch.optim.AdamW(
480
- params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
481
- )
482
- elif name == "lion8bit":
483
- return bnb.optim.Lion8bit(
484
- params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
485
- percentile_clipping=percentile_clipping
486
- )
487
- elif name == "adafactor":
488
- from transformers import Adafactor
489
- return Adafactor(
490
- params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
491
- warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
492
- beta1=0.9, weight_decay=0.01
493
- )
494
- else:
495
- raise ValueError(f"Unknown optimizer: {name}")
496
-
497
- if fbp:
498
- # Создаем отдельный оптимизатор для каждого параметра
499
- optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
500
-
501
- def optimizer_hook(param):
502
- optimizer_dict[param].step()
503
- optimizer_dict[param].zero_grad(set_to_none=True)
504
-
505
- for param in trainable_params:
506
- param.register_post_accumulate_grad_hook(optimizer_hook)
507
-
508
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
509
- else:
510
- optimizer = create_optimizer(optimizer_type, unet.parameters())
511
-
512
- def lr_schedule(step):
513
- x = step / (total_training_steps * world_size)
514
- warmup = warmup_percent
515
-
516
- if not use_decay:
517
- return base_learning_rate
518
- if x < warmup:
519
- return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
520
-
521
- decay_ratio = (x - warmup) / (1 - warmup)
522
- return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
523
- (1 + math.cos(math.pi * decay_ratio))
524
-
525
- lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
526
-
527
- num_params = sum(p.numel() for p in unet.parameters())
528
- print(f"[rank {accelerator.process_index}] total params: {num_params}")
529
- # Проверка на NaN/Inf
530
- for name, param in unet.named_parameters():
531
- if torch.isnan(param).any() or torch.isinf(param).any():
532
- print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
533
- # Опционально: заменить на нормальные значения
534
- #param.data = torch.randn_like(param) * 0.01
535
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
536
-
537
- # Регистрация хуков ПОСЛЕ prepare
538
- if dispersive_loss_enabled:
539
- dispersive_hook.register_hooks(unet, "down_blocks.2")
540
-
541
- # --------------------------- Фиксированные семплы для генерации ---------------------------
542
- # Примеры фиксированных семплов по размерам
543
- fixed_samples = get_fixed_samples_by_resolution(dataset)
544
-
545
- @torch.compiler.disable()
546
- @torch.no_grad()
547
- def generate_and_save_samples(fixed_samples_cpu, step):
548
- """
549
- Генерирует семплы для каждого из разрешений и сохраняет их.
550
-
551
- Args:
552
- fixed_samples_cpu: Словарь, где ключи - размеры (width, height),
553
- а значения - кортежи (latents, embeddings, text) на CPU.
554
- step: Текущий шаг обучения
555
- """
556
- original_model = None # Инициализируем, чтобы finally не ругался
557
- try:
558
-
559
- original_model = accelerator.unwrap_model(unet).eval()
560
-
561
- vae.to(device=device, dtype=dtype)
562
- vae.eval()
563
-
564
- scheduler.set_timesteps(n_diffusion_steps)
565
-
566
- all_generated_images = []
567
- all_captions = []
568
-
569
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
570
- width, height = size
571
-
572
- sample_latents = sample_latents.to(dtype=dtype)
573
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype)
574
-
575
- # Инициализируем латенты случайным шумом
576
- # sample_latents уже в dtype, так что noise будет создан в dtype
577
- noise = torch.randn(
578
- sample_latents.shape, # Используем форму от sample_latents, которые теперь на GPU и fp16
579
- generator=gen,
580
- device=device,
581
- dtype=sample_latents.dtype
582
- )
583
- current_latents = noise.clone()
584
-
585
- # Подготовка текстовых эмбеддингов для guidance
586
- if guidance_scale > 0:
587
- # empty_embeddings должны быть того же типа и на том же устройстве
588
- empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
589
- text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
590
- else:
591
- text_embeddings_batch = sample_text_embeddings
592
-
593
- for t in scheduler.timesteps:
594
- t_batch = t.repeat(current_latents.shape[0]).to(device) # Убедимся, что t на устройстве
595
-
596
- if guidance_scale > 0:
597
- latent_model_input = torch.cat([current_latents] * 2)
598
- else:
599
- latent_model_input = current_latents
600
-
601
- latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
602
-
603
- # Предсказание шума (UNet)
604
- noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
605
-
606
- if guidance_scale > 0:
607
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
608
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
609
-
610
- current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
611
-
612
- #print(f"current_latents Min: {current_latents.min()} Max: {current_latents.max()}")
613
- # Декодирование через VAE
614
- latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
615
- decoded = vae.decode(latent_for_vae).sample
616
-
617
- # Преобразуем тензоры в PIL-изображения
618
- # Для математики с изображением (нормализация) лучше перейти в fp32
619
- decoded_fp32 = decoded.to(torch.float32)
620
- for img_idx, img_tensor in enumerate(decoded_fp32):
621
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
622
- # If NaNs or infs are present, print them
623
- if np.isnan(img).any():
624
- print("NaNs found, saving stoped! Step:", step)
625
- save_model = False
626
- pil_img = Image.fromarray((img * 255).astype("uint8"))
627
-
628
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
629
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
630
- max_w_overall = max(255, max_w_overall)
631
- max_h_overall = max(255, max_h_overall)
632
-
633
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
634
- all_generated_images.append(padded_img)
635
-
636
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
637
- all_captions.append(caption_text)
638
-
639
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
640
- pil_img.save(sample_path, "JPEG", quality=96)
641
-
642
- if use_wandb and accelerator.is_main_process:
643
- wandb_images = [
644
- wandb.Image(img, caption=f"{all_captions[i]}")
645
- for i, img in enumerate(all_generated_images)
646
- ]
647
- wandb.log({"generated_images": wandb_images, "global_step": step})
648
-
649
- finally:
650
- vae.to("cpu") # Перемещаем VAE обратно на CPU
651
- # Очистка переменных, которые являются тензорами и были созданы в функции
652
- for var in list(locals().keys()):
653
- if isinstance(locals()[var], torch.Tensor):
654
- del locals()[var]
655
-
656
- torch.cuda.empty_cache()
657
- gc.collect()
658
-
659
- # --------------------------- Генерация сэмплов перед обучением ---------------------------
660
- if accelerator.is_main_process:
661
- if save_model:
662
- print("Генерация сэмплов до старта обучения...")
663
- generate_and_save_samples(fixed_samples,0)
664
- accelerator.wait_for_everyone()
665
-
666
- # Модифицируем функцию сохранения модели для поддержки LoRA
667
- def save_checkpoint(unet,variant=""):
668
- if accelerator.is_main_process:
669
- if lora_name:
670
- # Сохраняем только LoRA адаптеры
671
- save_lora_checkpoint(unet)
672
- else:
673
- # Сохраняем полную модель
674
- if variant!="":
675
- accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
676
- else:
677
- accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
678
- unet = unet.to(dtype=dtype)
679
-
680
- # --------------------------- Тренировочный цикл ---------------------------
681
- # Для логирования среднего лосса каждые % эпохи
682
- if accelerator.is_main_process:
683
- print(f"Total steps per GPU: {total_training_steps}")
684
-
685
- epoch_loss_points = []
686
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
687
-
688
- # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
689
- steps_per_epoch = len(dataloader)
690
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
691
- min_loss = 1.
692
-
693
- # Начинаем с указанной эпохи (полезно при возобновлени��)
694
- for epoch in range(start_epoch, start_epoch + num_epochs):
695
- batch_losses = []
696
- batch_tlosses = []
697
- batch_grads = []
698
- #unet = unet.to(dtype = dtype)
699
- batch_sampler.set_epoch(epoch)
700
- accelerator.wait_for_everyone()
701
- unet.train()
702
- print("epoch:",epoch)
703
- for step, (latents, embeddings) in enumerate(dataloader):
704
- with accelerator.accumulate(unet):
705
- if save_model == False and step == 5 :
706
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
707
- print(f"Шаг {step}: {used_gb:.2f} GB")
708
-
709
- # Forward pass
710
- noise = torch.randn_like(latents, dtype=latents.dtype)
711
-
712
- timesteps = torch.randint(steps_offset, scheduler.config.num_train_timesteps,
713
- (latents.shape[0],), device=device).long()
714
-
715
- # Добавляем шум к латентам
716
- noisy_latents = scheduler.add_noise(latents, noise, timesteps)
717
-
718
- # Очищаем активации перед forward pass
719
- if dispersive_loss_enabled:
720
- dispersive_hook.clear_activations()
721
-
722
- # Используем целевое значение
723
- model_pred = unet(noisy_latents, timesteps, embeddings).sample
724
- target_pred = scheduler.get_velocity(latents, noise, timesteps)
725
-
726
- # Считаем лосс
727
- loss = torch.nn.functional.mse_loss(model_pred.float(), target_pred.float())
728
-
729
- # Dispersive Loss
730
- #Идентичные векторы: Loss = -0.0000
731
- #Ортогональные векторы: Loss = -3.9995
732
- if dispersive_loss_enabled:
733
- with torch.amp.autocast('cuda', enabled=False):
734
- dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
735
- if torch.isnan(dispersive_loss) or torch.isinf(dispersive_loss):
736
- print(f"Rank {accelerator.process_index}: Found nan/inf in dispersive_loss: {total_loss}")
737
-
738
- # Итоговый loss
739
- # dispersive_loss должен падать и тотал падать - поэтому плюс
740
- if dispersive_loss_enabled:
741
- total_loss = loss + dispersive_loss
742
- else:
743
- total_loss = loss
744
-
745
- # Проверяем на nan/inf перед backward
746
- if torch.isnan(loss) or torch.isinf(loss):
747
- print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
748
- save_model = False
749
- break
750
-
751
- if torch.isnan(total_loss) or torch.isinf(total_loss):
752
- print(f"Rank {accelerator.process_index}: Found nan/inf in total_loss: {total_loss}")
753
- print(f"Проблемный батч: step={step}, latents.shape={latents.shape}, embeddings.shape={embeddings.shape}")
754
- continue
755
-
756
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
757
- accelerator.wait_for_everyone()
758
-
759
- # Делаем backward через Accelerator
760
- accelerator.backward(total_loss)
761
-
762
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
763
- accelerator.wait_for_everyone()
764
-
765
- grad = torch.tensor(0.0, device=device)
766
- if not fbp:
767
- if accelerator.sync_gradients:
768
- with torch.amp.autocast('cuda', enabled=False):
769
- grad = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
770
- optimizer.step()
771
- lr_scheduler.step()
772
- optimizer.zero_grad(set_to_none=True)
773
-
774
- # Увеличиваем счетчик глобальных шагов
775
- global_step += 1
776
-
777
- # Обновляем прогресс-бар
778
- progress_bar.update(1)
779
-
780
- # Логируем метрики
781
- if accelerator.is_main_process:
782
- if fbp:
783
- current_lr = base_learning_rate
784
- else:
785
- current_lr = lr_scheduler.get_last_lr()[0]
786
- batch_losses.append(loss.detach().item())
787
- batch_tlosses.append(total_loss.detach().item())
788
- batch_grads.append(grad)
789
-
790
- # Логируем в Wandb
791
- if use_wandb and accelerator.sync_gradients:
792
- wandb.log({
793
- "mse_loss": loss.detach().item(),
794
- "learning_rate": current_lr,
795
- "epoch": epoch,
796
- "grad": grad,
797
- "global_step": global_step,
798
- **({"dispersive_loss": dispersive_loss} if dispersive_loss_enabled else {}),
799
- **({"total_loss": total_loss} if dispersive_loss_enabled else {})
800
- })
801
-
802
- # Генерируем сэмплы с заданным интервалом
803
- if global_step % sample_interval == 0:
804
- generate_and_save_samples(fixed_samples,global_step)
805
-
806
- # Выводим текущий лосс
807
- avg_loss = np.mean(batch_losses[-sample_interval:])
808
- avg_tloss = np.mean(batch_tlosses[-sample_interval:])
809
- avg_grad = torch.mean(torch.stack(batch_grads[-sample_interval:])).cpu().item()
810
- print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, grad: {avg_grad:.6f}")
811
-
812
- if save_model:
813
- print("saving:",avg_loss < min_loss*save_barrier)
814
- if avg_loss < min_loss*save_barrier:
815
- min_loss = avg_loss
816
- save_checkpoint(unet)
817
- if use_wandb:
818
- wandb.log({"interm_loss": avg_loss})
819
- wandb.log({"interm_grad": avg_grad})
820
- if dispersive_loss_enabled:
821
- wandb.log({"interm_totalloss": avg_tloss})
822
-
823
- # По окончании эпохи
824
- #accelerator.wait_for_everyone()
825
- if accelerator.is_main_process:
826
- avg_epoch_loss = np.mean(batch_losses)
827
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
828
- if use_wandb:
829
- wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
830
-
831
- # Завершение обучения - сохраняем финальную модель
832
- if dispersive_loss:
833
- dispersive_hook.remove_hooks()
834
- if accelerator.is_main_process:
835
- print("Обучение завершено! Сохраняем финальную модель...")
836
- # Сохраняем основную модель
837
- if save_model:
838
- save_checkpoint(unet,"fp16")
839
- accelerator.free_memory()
840
- if torch.distributed.is_initialized():
841
- torch.distributed.destroy_process_group()
842
-
843
- print("Готово!")
844
-
845
- # randomize ode timesteps
846
- # input_timestep = torch.round(
847
- # F.sigmoid(torch.randn((n,), device=latents.device)), decimals=3
848
- # )
849
-
850
- #def create_distribution(num_points, device=None):
851
- # # Диапазон вероятностей на оси x
852
- # x = torch.linspace(0, 1, num_points, device=device)
853
-
854
- # Пользовательская функция плотности вероятности
855
- # probabilities = -7.7 * ((x - 0.5) ** 2) + 2
856
-
857
- # Нормализация, чтобы сумма равнялась 1
858
- # probabilities /= probabilities.sum()
859
-
860
- # return x, probabilities
861
-
862
- #def sample_from_distribution(x, probabilities, n, device=None):
863
- # Выбор индексов на основе распределения вероятностей
864
- # indices = torch.multinomial(probabilities, n, replacement=True)
865
- # return x[indices]
866
-
867
- # Пример использования
868
- #num_points = 1000 # Количество точек в диапазоне
869
- #n = latents.shape[0] # Количество временных шагов для выборки
870
- #x, probabilities = create_distribution(num_points, device=latents.device)
871
- #timesteps = sample_from_distribution(x, probabilities, n, device=latents.device)
872
-
873
- # Преобразование в формат, подходящий для вашего кода
874
- #timesteps = (timesteps * (scheduler.config.num_train_timesteps - 1)).long()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train.py CHANGED
@@ -23,6 +23,7 @@ from diffusers.models.attention_processor import AttnProcessor2_0
23
  from datetime import datetime
24
  import bitsandbytes as bnb
25
  import torch.nn.functional as F
 
26
 
27
  # --------------------------- Параметры ---------------------------
28
  ds_path = "datasets/384"
@@ -43,7 +44,6 @@ unet_gradient = True
43
  clip_sample = False #Scheduler
44
  fixed_seed = False
45
  shuffle = True
46
- dispersive_loss_enabled = True
47
  torch.backends.cuda.matmul.allow_tf32 = True
48
  torch.backends.cudnn.allow_tf32 = True
49
  torch.backends.cuda.enable_mem_efficient_sdp(False)
@@ -86,14 +86,81 @@ if fixed_seed:
86
  if torch.cuda.is_available():
87
  torch.cuda.manual_seed_all(seed)
88
 
 
 
 
 
 
 
 
 
 
 
89
  # --------------------------- Параметры LoRA ---------------------------
90
- # pip install peft
91
- lora_name = "" #"nusha" # Имя для сохранения/загрузки LoRA адаптеров
92
- lora_rank = 32 # Ранг LoRA (чем меньше, тем компактнее модель)
93
- lora_alpha = 64 # Альфа параметр LoRA, определяющий масштаб
94
 
95
  print("init")
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  class AccelerateDispersiveLoss:
98
  def __init__(self, accelerator, temperature=0.5, weight=0.5):
99
  self.accelerator = accelerator
@@ -113,35 +180,26 @@ class AccelerateDispersiveLoss:
113
  break
114
 
115
  def hook_fn(self, module, input, output):
116
-
117
  if isinstance(output, tuple):
118
  activation = output[0]
119
  else:
120
  activation = output
121
-
122
  if len(activation.shape) > 2:
123
  activation = activation.view(activation.shape[0], -1)
124
-
125
- self.activations.append(activation.detach())
126
 
127
  def compute_dispersive_loss(self):
128
- if not self.activations:
129
- return torch.tensor(0.0, requires_grad=True)
130
-
131
- local_activations = self.activations[-1].float()
132
-
133
- batch_size = local_activations.shape[0]
134
- if batch_size < 2:
135
- return torch.tensor(0.0, requires_grad=True)
136
-
137
- # Нормализация и вычисление loss
138
- sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
139
- distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
140
- exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
141
- dispersive_loss = torch.log(torch.mean(exp_neg_dist))
142
-
143
- # ВАЖНО: он отриц и должен падать
144
- return dispersive_loss
145
 
146
  def clear_activations(self):
147
  self.activations.clear()
@@ -152,7 +210,6 @@ class AccelerateDispersiveLoss:
152
  self.hooks.clear()
153
 
154
 
155
-
156
  # --------------------------- Инициализация WandB ---------------------------
157
  if use_wandb and accelerator.is_main_process:
158
  wandb.init(project=project+lora_name, config={
@@ -170,14 +227,14 @@ gen = torch.Generator(device=device)
170
  gen.manual_seed(seed)
171
 
172
  # --------------------------- Загрузка моделей ---------------------------
173
- # VAE загружается на CPU для экономии GPU-памяти
174
- vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to("cpu").eval()
175
 
176
  # DDPMScheduler с V_Prediction и Zero-SNR
177
  scheduler = DDPMScheduler(
178
- num_train_timesteps=1000, # Полный график шагов для обучения
179
- prediction_type="v_prediction", # V-Prediction
180
- rescale_betas_zero_snr=True, # Включение Zero-SNR
181
  clip_sample = clip_sample,
182
  steps_offset = steps_offset
183
  )
@@ -193,7 +250,6 @@ class DistributedResolutionBatchSampler(Sampler):
193
  self.drop_last = drop_last
194
  self.epoch = 0
195
 
196
- # Используем numpy для ускорения
197
  try:
198
  widths = np.array(dataset["width"])
199
  heights = np.array(dataset["height"])
@@ -201,16 +257,12 @@ class DistributedResolutionBatchSampler(Sampler):
201
  widths = np.zeros(len(dataset))
202
  heights = np.zeros(len(dataset))
203
 
204
- # Создаем уникальные ключи для размеров
205
  self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
206
-
207
- # Группируем индексы по размерам используя numpy
208
  self.size_groups = {}
209
  for w, h in self.size_keys:
210
  mask = (widths == w) & (heights == h)
211
  self.size_groups[(w, h)] = np.where(mask)[0]
212
 
213
- # Предварительно вычисляем количество полных батчей для каждой группы
214
  self.group_num_batches = {}
215
  total_batches = 0
216
  for size, indices in self.size_groups.items():
@@ -218,46 +270,31 @@ class DistributedResolutionBatchSampler(Sampler):
218
  self.group_num_batches[size] = num_full_batches
219
  total_batches += num_full_batches
220
 
221
- # Округляем до числа, делящегося на num_replicas
222
  self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
223
 
224
  def __iter__(self):
225
- # print(f"Rank {self.rank}: Starting iteration")
226
- # Очищаем CUDA кэш перед формированием новых батчей
227
  if torch.cuda.is_available():
228
  torch.cuda.empty_cache()
229
  all_batches = []
230
  rng = np.random.RandomState(self.epoch)
231
 
232
  for size, indices in self.size_groups.items():
233
- # print(f"Rank {self.rank}: Processing size {size}, {len(indices)} samples")
234
  indices = indices.copy()
235
  if self.shuffle:
236
  rng.shuffle(indices)
237
-
238
  num_full_batches = self.group_num_batches[size]
239
  if num_full_batches == 0:
240
  continue
241
-
242
- # Берем только индексы для полных батчей
243
  valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
244
-
245
- # Reshape для быстрого разделения на батчи
246
  batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
247
-
248
- # Выбираем часть для текущего GPU
249
  start_idx = self.rank * self.batch_size
250
  end_idx = start_idx + self.batch_size
251
  gpu_batches = batches[:, start_idx:end_idx]
252
-
253
  all_batches.extend(gpu_batches)
254
 
255
  if self.shuffle:
256
  rng.shuffle(all_batches)
257
-
258
- # Синхронизируем все процессы после формирования батчей
259
  accelerator.wait_for_everyone()
260
- # print(f"Rank {self.rank}: Created {len(all_batches)} batches")
261
  return iter(all_batches)
262
 
263
  def __len__(self):
@@ -268,8 +305,6 @@ class DistributedResolutionBatchSampler(Sampler):
268
 
269
  # Функция для выборки фиксированных семплов по размерам
270
  def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
271
- """Выбирает фиксированные семплы для каждого уникального разрешения"""
272
- # Группируем по размерам
273
  size_groups = defaultdict(list)
274
  try:
275
  widths = dataset["width"]
@@ -281,26 +316,18 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
281
  size = (w, h)
282
  size_groups[size].append(i)
283
 
284
- # Выбираем фиксированные примеры из каждой группы
285
  fixed_samples = {}
286
  for size, indices in size_groups.items():
287
- # Определяем сколько семплов брать из этой группы
288
  n_samples = min(samples_per_group, len(indices))
289
  if len(size_groups)==1:
290
  n_samples = samples_to_generate
291
  if n_samples == 0:
292
  continue
293
-
294
- # Выбираем случайные индексы
295
  sample_indices = random.sample(indices, n_samples)
296
  samples_data = [dataset[idx] for idx in sample_indices]
297
-
298
- # Собираем данные
299
  latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
300
  embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
301
  texts = [item["text"] for item in samples_data]
302
-
303
- # Сохраняем для этого размера
304
  fixed_samples[size] = (latents, embeddings, texts)
305
 
306
  print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
@@ -312,7 +339,6 @@ else:
312
  dataset = load_from_disk(ds_path)
313
 
314
  def collate_fn_simple(batch):
315
- # Преобразуем список в тензоры и перемещаем на девайс
316
  latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
317
  embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
318
  return latents, embeddings
@@ -320,12 +346,8 @@ def collate_fn_simple(batch):
320
  def collate_fn(batch):
321
  if not batch:
322
  return [], []
323
-
324
- # Берем эталонную форму
325
  ref_vae_shape = np.array(batch[0]["vae"]).shape
326
  ref_embed_shape = np.array(batch[0]["embeddings"]).shape
327
-
328
- # Фильтруем
329
  valid_latents = []
330
  valid_embeddings = []
331
  for item in batch:
@@ -333,14 +355,10 @@ def collate_fn(batch):
333
  np.array(item["embeddings"]).shape == ref_embed_shape):
334
  valid_latents.append(item["vae"])
335
  valid_embeddings.append(item["embeddings"])
336
-
337
- # Создаем тензоры
338
  latents = torch.tensor(np.array(valid_latents)).to(device,dtype=dtype)
339
  embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype)
340
-
341
  return latents, embeddings
342
 
343
- # Создаем ResolutionBatchSampler на основе индексов от DistributedSampler
344
  batch_sampler = DistributedResolutionBatchSampler(
345
  dataset=dataset,
346
  batch_size=batch_size,
@@ -349,71 +367,53 @@ batch_sampler = DistributedResolutionBatchSampler(
349
  shuffle=shuffle
350
  )
351
 
352
- # Создаем DataLoader
353
  dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
354
-
355
  print("Total samples",len(dataloader))
356
  dataloader = accelerator.prepare(dataloader)
357
 
358
- # Инициализация переменных для возобновления обучения
359
  start_epoch = 0
360
  global_step = 0
361
-
362
- # Расчёт общего количества шагов
363
  total_training_steps = (len(dataloader) * num_epochs)
364
- # Get the world size
365
  world_size = accelerator.state.num_processes
366
- #print(f"World Size: {world_size}")
367
 
368
  # Опция загрузки модели из последнего чекпоинта (если существует)
369
  latest_checkpoint = os.path.join(checkpoints_folder, project)
370
  if os.path.isdir(latest_checkpoint):
371
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
372
- #if dtype == torch.float32:
373
  unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
374
- #else:
375
- #unet = UNet2DConditionModel.from_pretrained(latest_checkpoint, variant="fp16").to(device=device,dtype=dtype)
 
 
 
376
  if unet_gradient:
377
  unet.enable_gradient_checkpointing()
378
- unet.set_use_memory_efficient_attention_xformers(False) # отключаем xformers
379
  try:
380
- unet.set_attn_processor(AttnProcessor2_0()) # Используем стандартный AttnProcessor
381
  except Exception as e:
382
  print(f"Ошибка при включении SDPA: {e}")
383
- print("Попытка использовать enable_xformers_memory_efficient_attention.")
384
  unet.set_use_memory_efficient_attention_xformers(True)
385
 
386
- if hasattr(torch.backends.cuda, "flash_sdp_enabled"):
387
- print(f"torch.backends.cuda.flash_sdp_enabled(): {torch.backends.cuda.flash_sdp_enabled()}")
388
- if hasattr(torch.backends.cuda, "mem_efficient_sdp_enabled"):
389
- print(f"torch.backends.cuda.mem_efficient_sdp_enabled(): {torch.backends.cuda.mem_efficient_sdp_enabled()}")
390
- if hasattr(torch.nn.functional, "get_flash_attention_available"):
391
- print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
392
-
393
- # Регистрируем хук на модел
394
- if dispersive_loss_enabled:
395
  dispersive_hook = AccelerateDispersiveLoss(
396
  accelerator=accelerator,
397
  temperature=dispersive_temperature,
398
  weight=dispersive_weight
399
  )
400
-
401
- if torch_compile:
402
- print("compiling")
403
- torch.set_float32_matmul_precision('high')
404
- unet = torch.compile(unet, mode="reduce-overhead", fullgraph=False)
405
- print("compiling - ok")
406
 
407
  if lora_name:
408
  print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
409
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
410
  from peft.tuners.lora import LoraModel
411
  import os
412
- # 1. Замораживаем все параметры UNet
413
  unet.requires_grad_(False)
414
  print("Параметры базового UNet заморожены.")
415
 
416
- # 2. Создаем конфигурацию LoRA
417
  lora_config = LoraConfig(
418
  r=lora_rank,
419
  lora_alpha=lora_alpha,
@@ -421,51 +421,33 @@ if lora_name:
421
  )
422
  unet.add_adapter(lora_config)
423
 
424
- # 3. Оборачиваем UNet в PEFT-модель
425
  from peft import get_peft_model
426
-
427
  peft_unet = get_peft_model(unet, lora_config)
428
-
429
- # 4. Получаем параметры для оптимизации
430
  params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
431
-
432
 
433
- # 5. Выводим информацию о количестве параметров
434
  if accelerator.is_main_process:
435
  lora_params_count = sum(p.numel() for p in params_to_optimize)
436
  total_params_count = sum(p.numel() for p in unet.parameters())
437
  print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
438
  print(f"Общее количество параметров UNet: {total_params_count:,}")
439
 
440
- # 6. Путь для сохранения
441
  lora_save_path = os.path.join("lora", lora_name)
442
- os.makedirs(lora_save_path, exist_ok=True)
443
 
444
- # 7. Функция для сохранения
445
  def save_lora_checkpoint(model):
446
  if accelerator.is_main_process:
447
  print(f"Сохраняем LoRA ��даптеры в {lora_save_path}")
448
  from peft.utils.save_and_load import get_peft_model_state_dict
449
- # Получаем state_dict только LoRA
450
  lora_state_dict = get_peft_model_state_dict(model)
451
-
452
- # Сохраняем веса
453
  torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
454
-
455
- # Сохраняем конфиг
456
  model.peft_config["default"].save_pretrained(lora_save_path)
457
- # SDXL must be compatible
458
  from diffusers import StableDiffusionXLPipeline
459
  StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
460
 
461
  # --------------------------- Оптимизатор ---------------------------
462
- # Определяем параметры для оптимизации
463
- #unet = torch.compile(unet)
464
  if lora_name:
465
- # Если используется LoRA, оптимизируем только параметры LoRA
466
  trainable_params = [p for p in unet.parameters() if p.requires_grad]
467
  else:
468
- # Иначе оптимизируем все параметры
469
  if fbp:
470
  trainable_params = list(unet.parameters())
471
 
@@ -495,71 +477,48 @@ def create_optimizer(name, params):
495
  raise ValueError(f"Unknown optimizer: {name}")
496
 
497
  if fbp:
498
- # Создаем отдельный оптимизатор для каждого параметра
499
  optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
500
-
501
  def optimizer_hook(param):
502
  optimizer_dict[param].step()
503
  optimizer_dict[param].zero_grad(set_to_none=True)
504
-
505
  for param in trainable_params:
506
  param.register_post_accumulate_grad_hook(optimizer_hook)
507
-
508
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
509
  else:
510
  optimizer = create_optimizer(optimizer_type, unet.parameters())
511
-
512
  def lr_schedule(step):
513
  x = step / (total_training_steps * world_size)
514
  warmup = warmup_percent
515
-
516
  if not use_decay:
517
  return base_learning_rate
518
  if x < warmup:
519
  return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
520
-
521
  decay_ratio = (x - warmup) / (1 - warmup)
522
  return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
523
  (1 + math.cos(math.pi * decay_ratio))
524
-
525
  lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
526
 
527
  num_params = sum(p.numel() for p in unet.parameters())
528
  print(f"[rank {accelerator.process_index}] total params: {num_params}")
529
- # Проверка на NaN/Inf
530
  for name, param in unet.named_parameters():
531
  if torch.isnan(param).any() or torch.isinf(param).any():
532
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
533
- # Опционально: заменить на нормальные значения
534
- #param.data = torch.randn_like(param) * 0.01
535
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
536
-
537
  # Регистрация хуков ПОСЛЕ prepare
538
- if dispersive_loss_enabled:
539
  dispersive_hook.register_hooks(unet, "down_blocks.2")
540
 
541
  # --------------------------- Фиксированные семплы для генерации ---------------------------
542
- # Примеры фиксированных семплов по размерам
543
  fixed_samples = get_fixed_samples_by_resolution(dataset)
544
 
545
  @torch.compiler.disable()
546
  @torch.no_grad()
547
  def generate_and_save_samples(fixed_samples_cpu, step):
548
- """
549
- Генерирует семплы для каждого из разрешений и сохраняет их.
550
-
551
- Args:
552
- fixed_samples_cpu: Словарь, где ключи - размеры (width, height),
553
- а значения - кортежи (latents, embeddings, text) на CPU.
554
- step: Текущий шаг обучения
555
- """
556
- original_model = None # Инициализируем, чтобы finally не ругался
557
  try:
558
-
559
- original_model = accelerator.unwrap_model(unet).eval()
560
-
561
- vae.to(device=device, dtype=dtype)
562
- vae.eval()
563
 
564
  scheduler.set_timesteps(n_diffusion_steps)
565
 
@@ -568,40 +527,32 @@ def generate_and_save_samples(fixed_samples_cpu, step):
568
 
569
  for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
570
  width, height = size
 
 
571
 
572
- sample_latents = sample_latents.to(dtype=dtype)
573
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype)
574
-
575
- # Инициализируем латенты случайным шумом
576
- # sample_latents уже в dtype, так что noise будет создан в dtype
577
  noise = torch.randn(
578
- sample_latents.shape, # Используем форму от sample_latents, которые теперь на GPU и fp16
579
  generator=gen,
580
  device=device,
581
  dtype=sample_latents.dtype
582
  )
583
  current_latents = noise.clone()
584
 
585
- # Подготовка текстовых эмбеддингов для guidance
586
  if guidance_scale > 0:
587
- # empty_embeddings должны быть того же типа и на том же устройстве
588
  empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
589
  text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
590
  else:
591
  text_embeddings_batch = sample_text_embeddings
592
 
593
  for t in scheduler.timesteps:
594
- t_batch = t.repeat(current_latents.shape[0]).to(device) # Убедимся, что t на устройстве
595
-
596
  if guidance_scale > 0:
597
  latent_model_input = torch.cat([current_latents] * 2)
598
  else:
599
  latent_model_input = current_latents
600
 
601
  latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
602
-
603
- # Предсказание шума (UNet)
604
- noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
605
 
606
  if guidance_scale > 0:
607
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -609,20 +560,14 @@ def generate_and_save_samples(fixed_samples_cpu, step):
609
 
610
  current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
611
 
612
- #print(f"current_latents Min: {current_latents.min()} Max: {current_latents.max()}")
613
- # Декодирование через VAE
614
- latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + vae.config.shift_factor
615
- decoded = vae.decode(latent_for_vae).sample
616
 
617
- # Преобразуем тензоры в PIL-изображения
618
- # Для математики с изображением (нормализация) лучше перейти в fp32
619
  decoded_fp32 = decoded.to(torch.float32)
620
  for img_idx, img_tensor in enumerate(decoded_fp32):
621
  img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
622
- # If NaNs or infs are present, print them
623
  if np.isnan(img).any():
624
- print("NaNs found, saving stoped! Step:", step)
625
- save_model = False
626
  pil_img = Image.fromarray((img * 255).astype("uint8"))
627
 
628
  max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
@@ -645,17 +590,15 @@ def generate_and_save_samples(fixed_samples_cpu, step):
645
  for i, img in enumerate(all_generated_images)
646
  ]
647
  wandb.log({"generated_images": wandb_images, "global_step": step})
648
-
649
- finally:
650
- vae.to("cpu") # Перемещаем VAE обратно на CPU
651
- # Очистка переменных, которые являются тензорами и были созданы в функции
652
  for var in list(locals().keys()):
653
  if isinstance(locals()[var], torch.Tensor):
654
  del locals()[var]
655
-
656
  torch.cuda.empty_cache()
657
  gc.collect()
658
-
659
  # --------------------------- Генерация сэмплов перед обучением ---------------------------
660
  if accelerator.is_main_process:
661
  if save_model:
@@ -667,35 +610,53 @@ accelerator.wait_for_everyone()
667
  def save_checkpoint(unet,variant=""):
668
  if accelerator.is_main_process:
669
  if lora_name:
670
- # Сохраняем только LoRA адаптеры
671
  save_lora_checkpoint(unet)
672
  else:
673
- # Сохраняем полную модель
674
  if variant!="":
675
  accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
676
  else:
677
  accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
678
  unet = unet.to(dtype=dtype)
679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
680
  # --------------------------- Тренировочный цикл ---------------------------
681
- # Для логирования среднего лосса каждые % эпохи
682
  if accelerator.is_main_process:
683
  print(f"Total steps per GPU: {total_training_steps}")
684
 
685
  epoch_loss_points = []
686
  progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
687
 
688
- # Определяем интервал для сэмплирования и логирования в пределах эпохи (10% эпохи)
689
  steps_per_epoch = len(dataloader)
690
  sample_interval = max(1, steps_per_epoch // sample_interval_share)
691
  min_loss = 1.
692
 
693
- # Начинаем с указанной эпохи (полезно при возобновлении)
694
  for epoch in range(start_epoch, start_epoch + num_epochs):
695
  batch_losses = []
696
  batch_tlosses = []
697
  batch_grads = []
698
- #unet = unet.to(dtype = dtype)
699
  batch_sampler.set_epoch(epoch)
700
  accelerator.wait_for_everyone()
701
  unet.train()
@@ -706,107 +667,103 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
706
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
707
  print(f"Шаг {step}: {used_gb:.2f} GB")
708
 
709
- # Forward pass
710
  noise = torch.randn_like(latents, dtype=latents.dtype)
711
-
712
- timesteps = torch.randint(steps_offset, scheduler.config.num_train_timesteps,
713
- (latents.shape[0],), device=device).long()
714
-
715
- # Добавляем шум к латентам
 
 
 
 
 
716
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
717
 
718
- # Очищаем активации перед forward pass
719
- if dispersive_loss_enabled:
720
  dispersive_hook.clear_activations()
721
 
722
- # Используем целевое значение
723
  model_pred = unet(noisy_latents, timesteps, embeddings).sample
724
  target_pred = scheduler.get_velocity(latents, noise, timesteps)
725
 
726
- # Считаем лосс
727
- loss = torch.nn.functional.mse_loss(model_pred.float(), target_pred.float())
728
-
729
- # Dispersive Loss
730
- #Идентичные векторы: Loss = -0.0000
731
- #Ортогональные векторы: Loss = -3.9995
732
- if dispersive_loss_enabled:
733
- with torch.amp.autocast('cuda', enabled=False):
734
- dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
735
- if torch.isnan(dispersive_loss) or torch.isinf(dispersive_loss):
736
- print(f"Rank {accelerator.process_index}: Found nan/inf in dispersive_loss: {total_loss}")
737
-
738
- # Итоговый loss
739
- # dispersive_loss должен падать и тотал падать - поэтому плюс
740
- if dispersive_loss_enabled:
741
- total_loss = loss + dispersive_loss
742
  else:
743
- total_loss = loss
744
 
745
- # Проверяем на nan/inf перед backward
746
- if torch.isnan(loss) or torch.isinf(loss):
747
- print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
748
- save_model = False
749
- break
 
750
 
751
- if torch.isnan(total_loss) or torch.isinf(total_loss):
752
- print(f"Rank {accelerator.process_index}: Found nan/inf in total_loss: {total_loss}")
753
- print(f"Проблемный батч: step={step}, latents.shape={latents.shape}, embeddings.shape={embeddings.shape}")
754
- continue
755
-
756
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
757
  accelerator.wait_for_everyone()
758
 
759
- # Делаем backward через Accelerator
760
  accelerator.backward(total_loss)
761
 
762
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
763
  accelerator.wait_for_everyone()
764
 
765
- grad = torch.tensor(0.0, device=device)
766
  if not fbp:
767
  if accelerator.sync_gradients:
768
  with torch.amp.autocast('cuda', enabled=False):
769
- grad = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
 
770
  optimizer.step()
771
  lr_scheduler.step()
772
  optimizer.zero_grad(set_to_none=True)
773
 
774
- # Увеличиваем счетчик глобальных шагов
775
  global_step += 1
776
-
777
- # Обновляем прогресс-бар
778
  progress_bar.update(1)
779
-
780
  # Логируем метрики
781
  if accelerator.is_main_process:
782
  if fbp:
783
  current_lr = base_learning_rate
784
  else:
785
  current_lr = lr_scheduler.get_last_lr()[0]
786
- batch_losses.append(loss.detach().item())
787
  batch_tlosses.append(total_loss.detach().item())
788
  batch_grads.append(grad)
789
-
790
- # Логируем в Wandb
 
 
 
 
 
 
 
 
 
 
 
 
791
  if use_wandb and accelerator.sync_gradients:
792
- wandb.log({
793
- "mse_loss": loss.detach().item(),
794
- "learning_rate": current_lr,
795
- "epoch": epoch,
796
- "grad": grad,
797
- "global_step": global_step,
798
- **({"dispersive_loss": dispersive_loss} if dispersive_loss_enabled else {}),
799
- **({"total_loss": total_loss} if dispersive_loss_enabled else {})
800
- })
801
-
802
  # Генерируем сэмплы с заданным интервалом
803
  if global_step % sample_interval == 0:
804
  generate_and_save_samples(fixed_samples,global_step)
805
-
806
- # Выводим текущий лосс
807
- avg_loss = np.mean(batch_losses[-sample_interval:])
808
- avg_tloss = np.mean(batch_tlosses[-sample_interval:])
809
- avg_grad = torch.mean(torch.stack(batch_grads[-sample_interval:])).cpu().item()
810
  print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, grad: {avg_grad:.6f}")
811
 
812
  if save_model:
@@ -815,25 +772,23 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
815
  min_loss = avg_loss
816
  save_checkpoint(unet)
817
  if use_wandb:
818
- wandb.log({"interm_loss": avg_loss})
819
- wandb.log({"interm_grad": avg_grad})
820
- if dispersive_loss_enabled:
821
- wandb.log({"interm_totalloss": avg_tloss})
 
822
 
823
- # По окончании эпохи
824
- #accelerator.wait_for_everyone()
825
  if accelerator.is_main_process:
826
- avg_epoch_loss = np.mean(batch_losses)
827
  print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
828
  if use_wandb:
829
  wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
830
 
831
  # Завершение обучения - сохраняем финальную модель
832
- if dispersive_loss:
833
  dispersive_hook.remove_hooks()
834
  if accelerator.is_main_process:
835
  print("Обучение завершено! Сохраняем финальную модель...")
836
- # Сохраняем основную модель
837
  if save_model:
838
  save_checkpoint(unet,"fp16")
839
  accelerator.free_memory()
@@ -841,34 +796,3 @@ if torch.distributed.is_initialized():
841
  torch.distributed.destroy_process_group()
842
 
843
  print("Готово!")
844
-
845
- # randomize ode timesteps
846
- # input_timestep = torch.round(
847
- # F.sigmoid(torch.randn((n,), device=latents.device)), decimals=3
848
- # )
849
-
850
- #def create_distribution(num_points, device=None):
851
- # # Диапазон вероятностей на оси x
852
- # x = torch.linspace(0, 1, num_points, device=device)
853
-
854
- # Пользовательская функция плотности вероятности
855
- # probabilities = -7.7 * ((x - 0.5) ** 2) + 2
856
-
857
- # Нормализация, чтобы сумма равнялась 1
858
- # probabilities /= probabilities.sum()
859
-
860
- # return x, probabilities
861
-
862
- #def sample_from_distribution(x, probabilities, n, device=None):
863
- # Выбор индексов на основе распределения вероятностей
864
- # indices = torch.multinomial(probabilities, n, replacement=True)
865
- # return x[indices]
866
-
867
- # Пример использования
868
- #num_points = 1000 # Количество точек в диапазоне
869
- #n = latents.shape[0] # Количество временных шагов для выборки
870
- #x, probabilities = create_distribution(num_points, device=latents.device)
871
- #timesteps = sample_from_distribution(x, probabilities, n, device=latents.device)
872
-
873
- # Преобразование в формат, подходящий для вашего кода
874
- #timesteps = (timesteps * (scheduler.config.num_train_timesteps - 1)).long()
 
23
  from datetime import datetime
24
  import bitsandbytes as bnb
25
  import torch.nn.functional as F
26
+ from collections import deque
27
 
28
  # --------------------------- Параметры ---------------------------
29
  ds_path = "datasets/384"
 
44
  clip_sample = False #Scheduler
45
  fixed_seed = False
46
  shuffle = True
 
47
  torch.backends.cuda.matmul.allow_tf32 = True
48
  torch.backends.cudnn.allow_tf32 = True
49
  torch.backends.cuda.enable_mem_efficient_sdp(False)
 
86
  if torch.cuda.is_available():
87
  torch.cuda.manual_seed_all(seed)
88
 
89
+ # --- Пропорции лоссов и окно медианного нормирования (КОЭФ., не значения) ---
90
+ # CHANGED: добавлен huber и dispersive в пропорции, суммы = 1.0
91
+ loss_ratios = {
92
+ "mse": 0.50,
93
+ "mae": 0.25,
94
+ "huber": 0.20,
95
+ "dispersive": 0.05,
96
+ }
97
+ median_coeff_steps = 256 # за сколько шагов считать медианные коэффициенты
98
+
99
  # --------------------------- Параметры LoRA ---------------------------
100
+ lora_name = ""
101
+ lora_rank = 32
102
+ lora_alpha = 64
 
103
 
104
  print("init")
105
 
106
+ # --------------------------- вспомогательные функции ---------------------------
107
+ def sample_timesteps_bias(
108
+ batch_size: int,
109
+ progress: float, # [0..1]
110
+ num_train_timesteps: int, # обычно 1000
111
+ steps_offset: int = 0,
112
+ device=None
113
+ ) -> torch.Tensor:
114
+ """
115
+ Возвращает псевдослучайные timesteps во всём диапазоне,
116
+ но с bias: на старте больше вероятности брать max (999),
117
+ к концу — больше вероятности брать min (0).
118
+
119
+ FIX: исправлена формула alpha/beta (раньше было перевёрнуто).
120
+ """
121
+ # Параметры Beta-распределения (FIX: alpha и beta поменяны местами по логике)
122
+ alpha = 1.0 + 4.0 * (1.0 - progress) # при progress=0 -> alpha ~10 (сдвиг к 1.0)
123
+ beta = 1.0 + 4.0 * progress # при progress=0 -> beta ~1
124
+
125
+ samples = torch.distributions.Beta(alpha, beta).sample((batch_size,)).to(device)
126
+
127
+ max_idx = num_train_timesteps - 1 - steps_offset
128
+ timesteps = steps_offset + (samples * max_idx).long()
129
+ return timesteps
130
+
131
+ # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ
132
+ class MedianLossNormalizer:
133
+ def __init__(self, desired_ratios: dict, window_steps: int):
134
+ # нормируем доли на случай, если сумма != 1
135
+ s = sum(desired_ratios.values())
136
+ self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
137
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
138
+ self.window = window_steps
139
+
140
+ def update_and_total(self, losses: dict):
141
+ """
142
+ losses: dict ключ->тензор (значения лоссов)
143
+ Поведение:
144
+ - буферим ABS(l) только для активных (ratio>0) лоссов
145
+ - coeff = ratio / median(abs(loss))
146
+ - total = sum(coeff * loss) по активным лоссам
147
+ CHANGED: буферим abs() — чтобы медиана была положительной и не ломала деление.
148
+ """
149
+ # буферим только активные лоссы
150
+ for k, v in losses.items():
151
+ if k in self.buffers and self.ratios.get(k, 0) > 0:
152
+ self.buffers[k].append(float(v.detach().abs().cpu()))
153
+
154
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
155
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
156
+
157
+ # суммируем только по активным (ratio>0)
158
+ total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
159
+ return total, coeffs, meds
160
+
161
+ # создаём normalizer после определения loss_ratios
162
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
163
+
164
  class AccelerateDispersiveLoss:
165
  def __init__(self, accelerator, temperature=0.5, weight=0.5):
166
  self.accelerator = accelerator
 
180
  break
181
 
182
  def hook_fn(self, module, input, output):
 
183
  if isinstance(output, tuple):
184
  activation = output[0]
185
  else:
186
  activation = output
 
187
  if len(activation.shape) > 2:
188
  activation = activation.view(activation.shape[0], -1)
189
+ self.activations.append(activation.detach().clone())
 
190
 
191
  def compute_dispersive_loss(self):
192
+ if not self.activations:
193
+ return torch.tensor(0.0, requires_grad=True, device=device)
194
+ local_activations = self.activations[-1].float()
195
+ batch_size = local_activations.shape[0]
196
+ if batch_size < 2:
197
+ return torch.tensor(0.0, requires_grad=True, device=device)
198
+ sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
199
+ distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
200
+ exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
201
+ dispersive_loss = torch.log(torch.mean(exp_neg_dist))
202
+ return dispersive_loss
 
 
 
 
 
 
203
 
204
  def clear_activations(self):
205
  self.activations.clear()
 
210
  self.hooks.clear()
211
 
212
 
 
213
  # --------------------------- Инициализация WandB ---------------------------
214
  if use_wandb and accelerator.is_main_process:
215
  wandb.init(project=project+lora_name, config={
 
227
  gen.manual_seed(seed)
228
 
229
  # --------------------------- Загрузка моделей ---------------------------
230
+ # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
231
+ vae = AutoencoderKL.from_pretrained("vae", variant="fp16").to(device="cpu", dtype=torch.float16).eval()
232
 
233
  # DDPMScheduler с V_Prediction и Zero-SNR
234
  scheduler = DDPMScheduler(
235
+ num_train_timesteps=1000,
236
+ prediction_type="v_prediction",
237
+ rescale_betas_zero_snr=True,
238
  clip_sample = clip_sample,
239
  steps_offset = steps_offset
240
  )
 
250
  self.drop_last = drop_last
251
  self.epoch = 0
252
 
 
253
  try:
254
  widths = np.array(dataset["width"])
255
  heights = np.array(dataset["height"])
 
257
  widths = np.zeros(len(dataset))
258
  heights = np.zeros(len(dataset))
259
 
 
260
  self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
 
 
261
  self.size_groups = {}
262
  for w, h in self.size_keys:
263
  mask = (widths == w) & (heights == h)
264
  self.size_groups[(w, h)] = np.where(mask)[0]
265
 
 
266
  self.group_num_batches = {}
267
  total_batches = 0
268
  for size, indices in self.size_groups.items():
 
270
  self.group_num_batches[size] = num_full_batches
271
  total_batches += num_full_batches
272
 
 
273
  self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
274
 
275
  def __iter__(self):
 
 
276
  if torch.cuda.is_available():
277
  torch.cuda.empty_cache()
278
  all_batches = []
279
  rng = np.random.RandomState(self.epoch)
280
 
281
  for size, indices in self.size_groups.items():
 
282
  indices = indices.copy()
283
  if self.shuffle:
284
  rng.shuffle(indices)
 
285
  num_full_batches = self.group_num_batches[size]
286
  if num_full_batches == 0:
287
  continue
 
 
288
  valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
 
 
289
  batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
 
 
290
  start_idx = self.rank * self.batch_size
291
  end_idx = start_idx + self.batch_size
292
  gpu_batches = batches[:, start_idx:end_idx]
 
293
  all_batches.extend(gpu_batches)
294
 
295
  if self.shuffle:
296
  rng.shuffle(all_batches)
 
 
297
  accelerator.wait_for_everyone()
 
298
  return iter(all_batches)
299
 
300
  def __len__(self):
 
305
 
306
  # Функция для выборки фиксированных семплов по размерам
307
  def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
 
 
308
  size_groups = defaultdict(list)
309
  try:
310
  widths = dataset["width"]
 
316
  size = (w, h)
317
  size_groups[size].append(i)
318
 
 
319
  fixed_samples = {}
320
  for size, indices in size_groups.items():
 
321
  n_samples = min(samples_per_group, len(indices))
322
  if len(size_groups)==1:
323
  n_samples = samples_to_generate
324
  if n_samples == 0:
325
  continue
 
 
326
  sample_indices = random.sample(indices, n_samples)
327
  samples_data = [dataset[idx] for idx in sample_indices]
 
 
328
  latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
329
  embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
330
  texts = [item["text"] for item in samples_data]
 
 
331
  fixed_samples[size] = (latents, embeddings, texts)
332
 
333
  print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
 
339
  dataset = load_from_disk(ds_path)
340
 
341
  def collate_fn_simple(batch):
 
342
  latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
343
  embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
344
  return latents, embeddings
 
346
  def collate_fn(batch):
347
  if not batch:
348
  return [], []
 
 
349
  ref_vae_shape = np.array(batch[0]["vae"]).shape
350
  ref_embed_shape = np.array(batch[0]["embeddings"]).shape
 
 
351
  valid_latents = []
352
  valid_embeddings = []
353
  for item in batch:
 
355
  np.array(item["embeddings"]).shape == ref_embed_shape):
356
  valid_latents.append(item["vae"])
357
  valid_embeddings.append(item["embeddings"])
 
 
358
  latents = torch.tensor(np.array(valid_latents)).to(device,dtype=dtype)
359
  embeddings = torch.tensor(np.array(valid_embeddings)).to(device,dtype=dtype)
 
360
  return latents, embeddings
361
 
 
362
  batch_sampler = DistributedResolutionBatchSampler(
363
  dataset=dataset,
364
  batch_size=batch_size,
 
367
  shuffle=shuffle
368
  )
369
 
 
370
  dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
 
371
  print("Total samples",len(dataloader))
372
  dataloader = accelerator.prepare(dataloader)
373
 
 
374
  start_epoch = 0
375
  global_step = 0
 
 
376
  total_training_steps = (len(dataloader) * num_epochs)
 
377
  world_size = accelerator.state.num_processes
 
378
 
379
  # Опция загрузки модели из последнего чекпоинта (если существует)
380
  latest_checkpoint = os.path.join(checkpoints_folder, project)
381
  if os.path.isdir(latest_checkpoint):
382
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
 
383
  unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
384
+ if torch_compile:
385
+ print("compiling")
386
+ torch.set_float32_matmul_precision('high')
387
+ unet = torch.compile(unet)
388
+ print("compiling - ok")
389
  if unet_gradient:
390
  unet.enable_gradient_checkpointing()
391
+ unet.set_use_memory_efficient_attention_xformers(False)
392
  try:
393
+ unet.set_attn_processor(AttnProcessor2_0())
394
  except Exception as e:
395
  print(f"Ошибка при включении SDPA: {e}")
 
396
  unet.set_use_memory_efficient_attention_xformers(True)
397
 
398
+ # Создаём hook для dispersive только если нужно
399
+ if loss_ratios.get("dispersive", 0) > 0:
 
 
 
 
 
 
 
400
  dispersive_hook = AccelerateDispersiveLoss(
401
  accelerator=accelerator,
402
  temperature=dispersive_temperature,
403
  weight=dispersive_weight
404
  )
405
+ else:
406
+ # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
407
+ raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
 
 
 
408
 
409
  if lora_name:
410
  print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
411
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
412
  from peft.tuners.lora import LoraModel
413
  import os
 
414
  unet.requires_grad_(False)
415
  print("Параметры базового UNet заморожены.")
416
 
 
417
  lora_config = LoraConfig(
418
  r=lora_rank,
419
  lora_alpha=lora_alpha,
 
421
  )
422
  unet.add_adapter(lora_config)
423
 
 
424
  from peft import get_peft_model
 
425
  peft_unet = get_peft_model(unet, lora_config)
 
 
426
  params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
 
427
 
 
428
  if accelerator.is_main_process:
429
  lora_params_count = sum(p.numel() for p in params_to_optimize)
430
  total_params_count = sum(p.numel() for p in unet.parameters())
431
  print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
432
  print(f"Общее количество параметров UNet: {total_params_count:,}")
433
 
 
434
  lora_save_path = os.path.join("lora", lora_name)
435
+ os.makedirs(lora_save_path, exist_ok=True)
436
 
 
437
  def save_lora_checkpoint(model):
438
  if accelerator.is_main_process:
439
  print(f"Сохраняем LoRA ��даптеры в {lora_save_path}")
440
  from peft.utils.save_and_load import get_peft_model_state_dict
 
441
  lora_state_dict = get_peft_model_state_dict(model)
 
 
442
  torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
 
 
443
  model.peft_config["default"].save_pretrained(lora_save_path)
 
444
  from diffusers import StableDiffusionXLPipeline
445
  StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
446
 
447
  # --------------------------- Оптимизатор ---------------------------
 
 
448
  if lora_name:
 
449
  trainable_params = [p for p in unet.parameters() if p.requires_grad]
450
  else:
 
451
  if fbp:
452
  trainable_params = list(unet.parameters())
453
 
 
477
  raise ValueError(f"Unknown optimizer: {name}")
478
 
479
  if fbp:
 
480
  optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
 
481
  def optimizer_hook(param):
482
  optimizer_dict[param].step()
483
  optimizer_dict[param].zero_grad(set_to_none=True)
 
484
  for param in trainable_params:
485
  param.register_post_accumulate_grad_hook(optimizer_hook)
 
486
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
487
  else:
488
  optimizer = create_optimizer(optimizer_type, unet.parameters())
 
489
  def lr_schedule(step):
490
  x = step / (total_training_steps * world_size)
491
  warmup = warmup_percent
 
492
  if not use_decay:
493
  return base_learning_rate
494
  if x < warmup:
495
  return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
 
496
  decay_ratio = (x - warmup) / (1 - warmup)
497
  return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
498
  (1 + math.cos(math.pi * decay_ratio))
 
499
  lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
500
 
501
  num_params = sum(p.numel() for p in unet.parameters())
502
  print(f"[rank {accelerator.process_index}] total params: {num_params}")
 
503
  for name, param in unet.named_parameters():
504
  if torch.isnan(param).any() or torch.isinf(param).any():
505
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
 
 
506
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
507
+
508
  # Регистрация хуков ПОСЛЕ prepare
509
+ if loss_ratios.get("dispersive", 0) > 0:
510
  dispersive_hook.register_hooks(unet, "down_blocks.2")
511
 
512
  # --------------------------- Фиксированные семплы для генерации ---------------------------
 
513
  fixed_samples = get_fixed_samples_by_resolution(dataset)
514
 
515
  @torch.compiler.disable()
516
  @torch.no_grad()
517
  def generate_and_save_samples(fixed_samples_cpu, step):
518
+ original_model = None
 
 
 
 
 
 
 
 
519
  try:
520
+ original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
521
+ vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
 
 
 
522
 
523
  scheduler.set_timesteps(n_diffusion_steps)
524
 
 
527
 
528
  for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
529
  width, height = size
530
+ sample_latents = sample_latents.to(dtype=dtype, device=device)
531
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
532
 
 
 
 
 
 
533
  noise = torch.randn(
534
+ sample_latents.shape,
535
  generator=gen,
536
  device=device,
537
  dtype=sample_latents.dtype
538
  )
539
  current_latents = noise.clone()
540
 
 
541
  if guidance_scale > 0:
 
542
  empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
543
  text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
544
  else:
545
  text_embeddings_batch = sample_text_embeddings
546
 
547
  for t in scheduler.timesteps:
548
+ t_batch = t.repeat(current_latents.shape[0]).to(device)
 
549
  if guidance_scale > 0:
550
  latent_model_input = torch.cat([current_latents] * 2)
551
  else:
552
  latent_model_input = current_latents
553
 
554
  latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
555
+ noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
 
 
556
 
557
  if guidance_scale > 0:
558
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
560
 
561
  current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
562
 
563
+ latent_for_vae = (current_latents.detach() / vae.config.scaling_factor) + getattr(vae.config, "shift_factor", 0.0)
564
+ decoded = vae.decode(latent_for_vae.to(torch.float16)).sample
 
 
565
 
 
 
566
  decoded_fp32 = decoded.to(torch.float32)
567
  for img_idx, img_tensor in enumerate(decoded_fp32):
568
  img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
 
569
  if np.isnan(img).any():
570
+ print("NaNs found, saving stopped! Step:", step)
 
571
  pil_img = Image.fromarray((img * 255).astype("uint8"))
572
 
573
  max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
 
590
  for i, img in enumerate(all_generated_images)
591
  ]
592
  wandb.log({"generated_images": wandb_images, "global_step": step})
593
+ finally:
594
+ # вернуть VAE на CPU (как было в твоём коде)
595
+ vae.to("cpu")
 
596
  for var in list(locals().keys()):
597
  if isinstance(locals()[var], torch.Tensor):
598
  del locals()[var]
 
599
  torch.cuda.empty_cache()
600
  gc.collect()
601
+
602
  # --------------------------- Генерация сэмплов перед обучением ---------------------------
603
  if accelerator.is_main_process:
604
  if save_model:
 
610
  def save_checkpoint(unet,variant=""):
611
  if accelerator.is_main_process:
612
  if lora_name:
 
613
  save_lora_checkpoint(unet)
614
  else:
 
615
  if variant!="":
616
  accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
617
  else:
618
  accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
619
  unet = unet.to(dtype=dtype)
620
 
621
+ def batch_pred_original_from_step(model_outputs, timesteps_tensor, noisy_latents, scheduler):
622
+ device = noisy_latents.device
623
+ dtype = noisy_latents.dtype
624
+
625
+ available_ts = scheduler.timesteps
626
+ if not isinstance(available_ts, torch.Tensor):
627
+ available_ts = torch.tensor(available_ts, device="cpu")
628
+ else:
629
+ available_ts = available_ts.cpu()
630
+
631
+ B = model_outputs.shape[0]
632
+ preds = []
633
+ for i in range(B):
634
+ t_i = int(timesteps_tensor[i].item())
635
+ diffs = torch.abs(available_ts - t_i)
636
+ idx = int(torch.argmin(diffs).item())
637
+ t_for_step = int(available_ts[idx].item())
638
+ model_out_i = model_outputs[i:i+1]
639
+ noisy_latent_i = noisy_latents[i:i+1]
640
+ step_out = scheduler.step(model_out_i, t_for_step, noisy_latent_i)
641
+ preds.append(step_out.pred_original_sample)
642
+
643
+ return torch.cat(preds, dim=0).to(device=device, dtype=dtype)
644
+
645
  # --------------------------- Тренировочный цикл ---------------------------
 
646
  if accelerator.is_main_process:
647
  print(f"Total steps per GPU: {total_training_steps}")
648
 
649
  epoch_loss_points = []
650
  progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
651
 
 
652
  steps_per_epoch = len(dataloader)
653
  sample_interval = max(1, steps_per_epoch // sample_interval_share)
654
  min_loss = 1.
655
 
 
656
  for epoch in range(start_epoch, start_epoch + num_epochs):
657
  batch_losses = []
658
  batch_tlosses = []
659
  batch_grads = []
 
660
  batch_sampler.set_epoch(epoch)
661
  accelerator.wait_for_everyone()
662
  unet.train()
 
667
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
668
  print(f"Шаг {step}: {used_gb:.2f} GB")
669
 
 
670
  noise = torch.randn_like(latents, dtype=latents.dtype)
671
+
672
+ progress = global_step / max(1, total_training_steps - 1)
673
+ timesteps = sample_timesteps_bias(
674
+ batch_size=latents.shape[0],
675
+ progress=progress,
676
+ num_train_timesteps=scheduler.config.num_train_timesteps,
677
+ steps_offset=steps_offset,
678
+ device=device
679
+ )
680
+
681
  noisy_latents = scheduler.add_noise(latents, noise, timesteps)
682
 
683
+ if loss_ratios.get("dispersive", 0) > 0:
 
684
  dispersive_hook.clear_activations()
685
 
 
686
  model_pred = unet(noisy_latents, timesteps, embeddings).sample
687
  target_pred = scheduler.get_velocity(latents, noise, timesteps)
688
 
689
+ # === Losses ===
690
+ losses_dict = {}
691
+
692
+ mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
693
+ losses_dict["mse"] = mse_loss
694
+ losses_dict["mae"] = F.l1_loss(model_pred.float(), target_pred.float())
695
+
696
+ # CHANGED: Huber (smooth_l1) loss added
697
+ losses_dict["huber"] = F.smooth_l1_loss(model_pred.float(), target_pred.float())
698
+
699
+ # === Dispersive loss ===
700
+ if loss_ratios.get("dispersive", 0) > 0:
701
+ disp_raw = dispersive_hook.compute_dispersive_loss().to(device) # может быть отрицательным
702
+ losses_dict["dispersive"] = dispersive_hook.weight * disp_raw
 
 
703
  else:
704
+ losses_dict["dispersive"] = torch.tensor(0.0, device=device)
705
 
706
+ # === Нормализация всех лоссов ===
707
+ abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
708
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
709
+
710
+ # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
711
+ batch_losses.append(mse_loss.detach().item())
712
 
 
 
 
 
 
713
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
714
  accelerator.wait_for_everyone()
715
 
716
+ # Backward
717
  accelerator.backward(total_loss)
718
 
719
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
720
  accelerator.wait_for_everyone()
721
 
722
+ grad = 0.0
723
  if not fbp:
724
  if accelerator.sync_gradients:
725
  with torch.amp.autocast('cuda', enabled=False):
726
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
727
+ grad = float(grad_val)
728
  optimizer.step()
729
  lr_scheduler.step()
730
  optimizer.zero_grad(set_to_none=True)
731
 
 
732
  global_step += 1
 
 
733
  progress_bar.update(1)
734
+
735
  # Логируем метрики
736
  if accelerator.is_main_process:
737
  if fbp:
738
  current_lr = base_learning_rate
739
  else:
740
  current_lr = lr_scheduler.get_last_lr()[0]
 
741
  batch_tlosses.append(total_loss.detach().item())
742
  batch_grads.append(grad)
743
+
744
+ # Логируем только активные лоссы (ratio>0)
745
+ active_keys = [k for k, v in loss_ratios.items() if v > 0]
746
+ log_data = {}
747
+ for k in active_keys:
748
+ v = losses_dict.get(k, None)
749
+ if v is None:
750
+ continue
751
+ log_data[f"loss/{k}"] = (v.item() if isinstance(v, torch.Tensor) else float(v))
752
+
753
+ log_data["loss/total"] = float(total_loss.item())
754
+ log_data["loss/lr"] = current_lr
755
+ for k, c in coeffs.items():
756
+ log_data[f"coeff/{k}"] = float(c)
757
  if use_wandb and accelerator.sync_gradients:
758
+ wandb.log(log_data, step=global_step)
759
+
 
 
 
 
 
 
 
 
760
  # Генерируем сэмплы с заданным интервалом
761
  if global_step % sample_interval == 0:
762
  generate_and_save_samples(fixed_samples,global_step)
763
+ last_n = sample_interval
764
+ avg_loss = float(np.mean(batch_losses[-last_n:])) if len(batch_losses) > 0 else 0.0
765
+ avg_tloss = float(np.mean(batch_tlosses[-last_n:])) if len(batch_tlosses) > 0 else 0.0
766
+ avg_grad = float(np.mean(batch_grads[-last_n:])) if len(batch_grads) > 0 else 0.0
 
767
  print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, grad: {avg_grad:.6f}")
768
 
769
  if save_model:
 
772
  min_loss = avg_loss
773
  save_checkpoint(unet)
774
  if use_wandb:
775
+ avg_data = {}
776
+ avg_data["avg/loss"] = avg_loss
777
+ avg_data["avg/tloss"] = avg_tloss
778
+ avg_data["avg/grad"] = avg_grad
779
+ wandb.log(avg_data, step=global_step)
780
 
 
 
781
  if accelerator.is_main_process:
782
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses)>0 else 0.0
783
  print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
784
  if use_wandb:
785
  wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
786
 
787
  # Завершение обучения - сохраняем финальную модель
788
+ if loss_ratios.get("dispersive", 0) > 0:
789
  dispersive_hook.remove_hooks()
790
  if accelerator.is_main_process:
791
  print("Обучение завершено! Сохраняем финальную модель...")
 
792
  if save_model:
793
  save_checkpoint(unet,"fp16")
794
  accelerator.free_memory()
 
796
  torch.distributed.destroy_process_group()
797
 
798
  print("Готово!")