recoilme committed
Commit e458bd6 · 1 Parent(s): f84b3af

dispersive_loss_disabled

samples/unet_320x576_0.jpg CHANGED

Git LFS Details (old)

  • SHA256: e3e1eed66df8356d9a4f3abb4fca96b2750c05a48c64689b76a003e1c396f029
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB

Git LFS Details (new)

  • SHA256: 9753bcbd8f3bdd79b1a981ba703c56413c3cdcb652957e22b7795223abf97ada
  • Pointer size: 130 Bytes
  • Size of remote file: 73.5 kB

samples/unet_384x576_0.jpg CHANGED

Git LFS Details (old)

  • SHA256: 807ca39bcd04ad0dbdcedde52bb7690f1bdb73c97e37bdb1262736461f791291
  • Pointer size: 130 Bytes
  • Size of remote file: 87.1 kB

Git LFS Details (new)

  • SHA256: 56f54bd854da6bcb33baa894ff063cbdcb8a4f9371f214aa7979020f7a93a479
  • Pointer size: 130 Bytes
  • Size of remote file: 89.2 kB

samples/unet_448x576_0.jpg CHANGED

Git LFS Details (old)

  • SHA256: a9579bb3c1d3b0e4bd76ad0834aae34a7d51ccecb675b6af42cceddece2d4f9a
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB

Git LFS Details (new)

  • SHA256: 29c80e76b4191508c8b8af7f3be240a8fb0d94bd498e5966a65b78e4e0ffcd7e
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB

samples/unet_512x576_0.jpg CHANGED

Git LFS Details (old)

  • SHA256: 5dfbc40067f048ba11326fe65a21b3730cd95d0b2f274ee45d529046aa44880e
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB

Git LFS Details (new)

  • SHA256: 370f5855f49baaa69979cc0fa830c698eb1afcf3e9ebcc73275c82518f65f5b2
  • Pointer size: 131 Bytes
  • Size of remote file: 113 kB

samples/unet_576x320_0.jpg CHANGED

Git LFS Details (old)

  • SHA256: ce9de204808a5d7f42de8f3a445fc130847ce50e3e0fa05d26fa68153d2aee5b
  • Pointer size: 130 Bytes
  • Size of remote file: 66.1 kB

Git LFS Details (new)

  • SHA256: bab97a9bb2d6d4fb9ced68b2afa38fbe88af438da6b9e80dcc1f4407368ff058
  • Pointer size: 130 Bytes
  • Size of remote file: 84.4 kB

samples/unet_576x384_0.jpg CHANGED

Git LFS Details (old)

  • SHA256: 637c8af557d8c7035cfe9209163dac867f722fcd6781ecdc7e1a46600a21c79a
  • Pointer size: 130 Bytes
  • Size of remote file: 46 kB

Git LFS Details (new)

  • SHA256: fa3e8315e008779b32fae73687f443c74dee2582b59005fc033b3239d32e7a8b
  • Pointer size: 130 Bytes
  • Size of remote file: 82.9 kB

samples/unet_576x448_0.jpg CHANGED

Git LFS Details (old)

  • SHA256: 7438a12c17f5b7b52160202df63d5484c96fb3ade27f57c1c3f460da532d3452
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB

Git LFS Details (new)

  • SHA256: 1570a01ed4a0b461b98e62afcbe0a9bb6a1d127264006ba67580a27305e67b79
  • Pointer size: 130 Bytes
  • Size of remote file: 81.2 kB

samples/unet_576x512_0.jpg CHANGED

Git LFS Details (old)

  • SHA256: 3ad86a4b554edf5c9a6677d78f55766874cdce0ccff24205a79b779a55674aef
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB

Git LFS Details (new)

  • SHA256: e4afe8993c88d19b20b9fa4f4856150cfe0f210553d3ef7b9810051f0b70e370
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB

samples/unet_576x576_0.jpg CHANGED

Git LFS Details (old)

  • SHA256: 74334562d66b14cef1d6e879294999356ae3e6b2666c4f0fe1b8d066a5fb36cb
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB

Git LFS Details (new)

  • SHA256: 70d1d7db450e142f7c60c05f615c95ef052a55a99f628480a81bf8ed07567305
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
train.py CHANGED
@@ -27,7 +27,7 @@ import torch.nn.functional as F
 # --------------------------- Parameters ---------------------------
 ds_path = "datasets/576"
 project = "unet"
-batch_size = 25
+batch_size = 30
 base_learning_rate = 9.5e-6
 min_learning_rate = 8.5e-6
 num_epochs = 20
@@ -43,20 +43,22 @@ unet_gradient = True
 clip_sample = False #Scheduler
 fixed_seed = False
 shuffle = True
-dispersive_loss = True
+dispersive_loss_enabled = False
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.enable_mem_efficient_sdp(False)
 dtype = torch.float32
 save_barrier = 1.03
 dispersive_temperature=0.5
-dispersive_weight=0.05
+dispersive_weight= 0.05
 percentile_clipping = 95 # 8bit optim
-clip_grad_norm = 0.5
+betta2 = 0.99 #0.97
+eps = 1e-6 #1e-5
+clip_grad_norm = 1. #.5
 steps_offset = 1 # Scheduler
 limit = 0
 checkpoints_folder = ""
-mixed_precision = "fp16"
+mixed_precision = "bf16" #"fp16"
 gradient_accumulation_steps = 2
 accelerator = Accelerator(
     mixed_precision=mixed_precision,
@@ -140,36 +142,6 @@ class AccelerateDispersiveLoss:
         # IMPORTANT: it is negative and should decrease
         return dispersive_loss
 
-    def compute_dispersive_loss2(self):
-        # If there are no activations, return 0
-        if not self.activations:
-            return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
-
-        # Work only with the main process's local activations
-        activations = self.activations[-1].float()
-
-        batch_size = activations.shape[0]
-        if batch_size < 2:
-            return torch.tensor(0.0, device=self.accelerator.device, requires_grad=True)
-
-        # Normalization
-        norm = torch.norm(activations, dim=1, keepdim=True).clamp(min=1e-12)
-        sf = activations / norm
-
-        # Compute pairwise distances
-        distance = torch.nn.functional.pdist(sf, p=2)
-        distance = distance.clamp(min=1e-12)
-        distance_squared = distance ** 2
-
-        # Compute the loss with clipping for stability
-        exp_neg_dist = torch.exp((-distance_squared / self.temperature).clamp(min=-20, max=20))
-        exp_neg_dist = exp_neg_dist + 1e-12
-
-        mean_exp = torch.mean(exp_neg_dist)
-        dispersive_loss = torch.log(mean_exp.clamp(min=1e-12))
-
-        return dispersive_loss
-
     def clear_activations(self):
         self.activations.clear()
 
@@ -177,140 +149,7 @@ class AccelerateDispersiveLoss:
         for hook in self.hooks:
             hook.remove()
         self.hooks.clear()
-
-class AccelerateDispersiveLoss2:
-    def __init__(self, accelerator, temperature=0.5, weight=0.5):
-        self.accelerator = accelerator
-        self.temperature = temperature
-        self.weight = weight
-        self.activations = []
-        self.hooks = []
-
-    def register_hooks(self, model, target_layer="down_blocks.0"):
-        # Get the "clean" model without the DDP wrapper
-        unwrapped_model = self.accelerator.unwrap_model(model)
-
-        print("=== Searching for layers in the unwrapped model ===")
-        for name, module in unwrapped_model.named_modules():
-            if target_layer in name:
-                hook = module.register_forward_hook(self.hook_fn)
-                self.hooks.append(hook)
-                print(f"✅ Hook registered on: {name}")
-                break
-
-    def hook_fn(self, module, input, output):
-        if isinstance(output, tuple):
-            activation = output[0]
-        else:
-            activation = output
-
-        if len(activation.shape) > 2:
-            activation = activation.view(activation.shape[0], -1)
-
-        self.activations.append(activation.detach())
-
-    def compute_dispersive_loss_fix(self):
-        if not self.activations:
-            return torch.tensor(0.0, requires_grad=True)
-
-        local_activations = self.activations[-1]
-
-        # Gather activations from all GPUs
-        if self.accelerator.num_processes > 1:
-            gathered_activations = self.accelerator.gather(local_activations)
-        else:
-            gathered_activations = local_activations
-
-        batch_size = gathered_activations.shape[0]
-        if batch_size < 2:
-            return torch.tensor(0.0, requires_grad=True)
-
-        # Cast to float32 for stability
-        gathered_activations = gathered_activations.float()
-
-        # Normalize with an eps for stability
-        norm = torch.norm(gathered_activations, dim=1, keepdim=True).clamp(min=1e-12)
-        sf = gathered_activations / norm
-
-        # Compute pairwise distances
-        distance = torch.nn.functional.pdist(sf, p=2)
-        distance = distance.clamp(min=1e-12)  # avoid values that are too small
-        distance_squared = distance ** 2
-
-        # Exponential with clipping
-        exp_neg_dist = torch.exp((-distance_squared / self.temperature).clamp(min=-20, max=20))
-        exp_neg_dist = exp_neg_dist + 1e-12  # avoid zeros
-
-        # Mean and log
-        mean_exp = torch.mean(exp_neg_dist)
-        dispersive_loss = torch.log(mean_exp.clamp(min=1e-12))
-
-        return dispersive_loss
-
-    def compute_dispersive_loss(self):
-        if not self.activations:
-            return torch.tensor(0.0, requires_grad=True)
-
-        local_activations = self.activations[-1].float()
-
-        # Gather activations from all GPUs
-        if self.accelerator.num_processes > 1:
-            gathered_activations = self.accelerator.gather(local_activations)
-        else:
-            gathered_activations = local_activations
-
-        batch_size = gathered_activations.shape[0]
-        if batch_size < 2:
-            return torch.tensor(0.0, requires_grad=True)
-
-        # Normalize and compute the loss
-        sf = gathered_activations / torch.norm(gathered_activations, dim=1, keepdim=True)
-        sf = sf.float()
-        distance = torch.nn.functional.pdist(sf, p=2) ** 2
-        exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
-        dispersive_loss = torch.log(torch.mean(exp_neg_dist))
-
-        # IMPORTANT: it is negative and should decrease
-        return dispersive_loss
-
-
-    def compute_dispersive_loss_single(self):
-        if not self.activations:
-            return torch.tensor(0.0, requires_grad=True)
-
-        local_activations = self.activations[-1]  # activations from the current GPU
-
-        # Gather activations from all GPUs
-        if self.accelerator.num_processes > 1:
-            # Use accelerate for the gather
-            gathered_activations = self.accelerator.gather(local_activations)
-        else:
-            gathered_activations = local_activations
-
-        # Compute the loss on the main process
-        if self.accelerator.is_main_process:
-            batch_size = gathered_activations.shape[0]
-            if batch_size < 2:
-                return torch.tensor(0.0, requires_grad=True)
-
-            # Normalize and compute the loss
-            sf = gathered_activations / torch.norm(gathered_activations, dim=1, keepdim=True)
-            distance = torch.nn.functional.pdist(sf, p=2) ** 2
-            exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
-            dispersive_loss = torch.log(torch.mean(exp_neg_dist))
-
-            return dispersive_loss
-        else:
-            # Return 0 on non-main processes
-            return torch.tensor(0.0, requires_grad=True)
-
-    def clear_activations(self):
-        self.activations.clear()
-
-    def remove_hooks(self):
-        for hook in self.hooks:
-            hook.remove()
-        self.hooks.clear()
 
 
 # --------------------------- WandB initialization ---------------------------
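Aside: the retained AccelerateDispersiveLoss class drives everything above through a single forward hook that captures flattened activations. A minimal, self-contained sketch of that pattern follows; the two-layer toy model is an illustrative stand-in for the repo's UNet, not part of this commit.

import torch
import torch.nn as nn

activations = []

def hook_fn(module, inputs, output):
    # Some blocks return tuples; keep the primary tensor and flatten per sample
    out = output[0] if isinstance(output, tuple) else output
    activations.append(out.flatten(start_dim=1).detach())

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
handle = model[1].register_forward_hook(hook_fn)   # analogous to hooking "down_blocks.2"
_ = model(torch.randn(2, 3, 32, 32))
print(activations[0].shape)  # torch.Size([2, 6272]): one flattened row per sample
handle.remove()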
@@ -551,7 +390,7 @@ if os.path.isdir(latest_checkpoint):
 print(f"torch.nn.functional.get_flash_attention_available(): {torch.nn.functional.get_flash_attention_available()}")
 
 # Register the hook on the model
-if dispersive_loss:
+if dispersive_loss_enabled:
     dispersive_hook = AccelerateDispersiveLoss(
         accelerator=accelerator,
         temperature=dispersive_temperature,
@@ -632,7 +471,7 @@ else:
 def create_optimizer(name, params):
     if name == "adam8bit":
         return bnb.optim.AdamW8bit(
-            params, lr=base_learning_rate, betas=(0.9, 0.97), eps=1e-5, weight_decay=0.001,
+            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001,
             percentile_clipping=percentile_clipping
         )
     elif name == "adam":
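A hedged usage note: the call site for this factory sits outside the diff, so the exact invocation below is an assumption.

# Assumed call site (not shown in this diff): the revised betta2/eps config
# values flow into AdamW8bit through the factory above.
optimizer = create_optimizer("adam8bit", unet.parameters())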
@@ -686,7 +525,7 @@ else:
 unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
 
 # Register hooks AFTER prepare
-if dispersive_loss:
+if dispersive_loss_enabled:
     dispersive_hook.register_hooks(unet, "down_blocks.2")
 
 # --------------------------- Fixed samples for generation ---------------------------
@@ -867,7 +706,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
 noisy_latents = scheduler.add_noise(latents, noise, timesteps)
 
 # Clear activations before the forward pass
-if dispersive_loss:
+if dispersive_loss_enabled:
     dispersive_hook.clear_activations()
 
 # Use the target value
@@ -880,7 +719,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
 # Dispersive Loss
 # Identical vectors: Loss = -0.0000
 # Orthogonal vectors: Loss = -3.9995
-if dispersive_loss:
+if dispersive_loss_enabled:
     with torch.amp.autocast('cuda', enabled=False):
         dispersive_loss = dispersive_hook.weight * dispersive_hook.compute_dispersive_loss()
         if torch.isnan(dispersive_loss) or torch.isinf(dispersive_loss):
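The two reference values in that comment follow directly from the loss formula log(mean(exp(-d^2 / temperature))) at temperature 0.5: identical vectors give d^2 = 0, orthogonal unit vectors give d^2 = 2, so exp(-4) ≈ 0.0183 and log of it ≈ -4. A minimal standalone sketch reproducing them (the toy tensors are illustrative):

# Standalone check of the values quoted above, using the same formula as
# compute_dispersive_loss with dispersive_temperature = 0.5.
import torch

def dispersive_loss_ref(x, temperature=0.5):
    sf = x / torch.norm(x, dim=1, keepdim=True)          # L2-normalize rows
    distance = torch.nn.functional.pdist(sf, p=2) ** 2   # squared pairwise distances
    return torch.log(torch.mean(torch.exp(-distance / temperature) + 1e-5))

identical = torch.ones(4, 8)    # equal rows   -> d^2 = 0 -> log(1 + 1e-5) ~ -0.0000
orthogonal = torch.eye(4, 8)    # one-hot rows -> d^2 = 2 -> ~ log(exp(-4)) ~ -3.9995
print(dispersive_loss_ref(identical), dispersive_loss_ref(orthogonal))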
@@ -888,13 +727,21 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
 
 # Total loss
 # dispersive_loss should decrease, and so should the total - hence the plus
-total_loss = loss + dispersive_loss
+if dispersive_loss_enabled:
+    total_loss = loss + dispersive_loss
+else:
+    total_loss = loss
 
 # Check for nan/inf before backward
 if torch.isnan(loss) or torch.isinf(loss):
     print(f"Rank {accelerator.process_index}: Found nan/inf in loss: {loss}")
     save_model = False
     break
+
+if torch.isnan(total_loss) or torch.isinf(total_loss):
+    print(f"Rank {accelerator.process_index}: Found nan/inf in total_loss: {total_loss}")
+    print(f"Problem batch: step={step}, latents.shape={latents.shape}, embeddings.shape={embeddings.shape}")
+    continue
 
 if (global_step % 100 == 0) or (global_step % sample_interval == 0):
     accelerator.wait_for_everyone()
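What presumably follows this hunk is the standard Accelerate backward step; it is outside the diff, so the sketch below is an assumption based on the config block (clip_grad_norm raised to 1.0) and the "grad" field logged in the next hunk.

# Assumed continuation (not part of this diff): backpropagate the combined loss
# and clip gradients under Accelerate's gradient-accumulation bookkeeping.
accelerator.backward(total_loss)
if accelerator.sync_gradients:
    # clip_grad_norm_ returns the pre-clip gradient norm, which the
    # wandb call below would log as "grad"
    grad = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()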
@@ -938,8 +785,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
     "epoch": epoch,
     "grad": grad,
     "global_step": global_step,
-    "dispersive_loss": dispersive_loss,
-    "total_loss": total_loss
+    **({"dispersive_loss": dispersive_loss} if dispersive_loss_enabled else {}),
+    **({"total_loss": total_loss} if dispersive_loss_enabled else {})
 })
 
 # Generate samples at the given interval
 
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:fd195f44f05b8269dfc948cbc9046d1585b6c126485aba158c549a58dee09ae3
+oid sha256:1e14b52a3c3ef2ae70a86b58094ade0d9dd40bee48b05a08a1fde027ed735e6f
 size 7014306128