recoilme commited on
Commit
28ee1df
·
1 Parent(s): a31578a
samples/unet_192x384_0.jpg DELETED

Git LFS Details

  • SHA256: 87014828d47898b6db57757ef86957fd30d58730ede1320a781067cbd0a83ece
  • Pointer size: 130 Bytes
  • Size of remote file: 24.5 kB
samples/unet_256x384_0.jpg DELETED

Git LFS Details

  • SHA256: a201862a1de3c3a6bfda9ff91a4fa469921da68f27eb78853dfdcb5201d6a693
  • Pointer size: 130 Bytes
  • Size of remote file: 46.3 kB
samples/unet_320x384_0.jpg DELETED

Git LFS Details

  • SHA256: 7e42532647e82db14135745a792cccd0e5728be3980bc2a7b106cf2509613e30
  • Pointer size: 130 Bytes
  • Size of remote file: 58.8 kB
samples/unet_384x192_0.jpg DELETED

Git LFS Details

  • SHA256: 93bfeeae000edcd71ddbb1d8938b99ee6325af1a11b128a815227bf21f02040c
  • Pointer size: 130 Bytes
  • Size of remote file: 36.3 kB
samples/unet_384x256_0.jpg DELETED

Git LFS Details

  • SHA256: dfc3e7c6d36621c47a551765e088fc948b51874309aef42204e17c6aa4495a0e
  • Pointer size: 130 Bytes
  • Size of remote file: 45.3 kB
samples/unet_384x320_0.jpg DELETED

Git LFS Details

  • SHA256: 7b2961101c9e8b0270c9cce9bef76a7abc69cae2598edd34e19b8397c996f974
  • Pointer size: 130 Bytes
  • Size of remote file: 60 kB
train.py CHANGED
@@ -8,7 +8,7 @@ from torch.utils.data.distributed import DistributedSampler
8
  from torch.optim.lr_scheduler import LambdaLR
9
  from collections import defaultdict
10
  from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL, DDPMScheduler
12
  from accelerate import Accelerator
13
  from datasets import load_from_disk
14
  from tqdm import tqdm
@@ -28,10 +28,10 @@ from collections import deque
28
  # --------------------------- Параметры ---------------------------
29
  ds_path = "/workspace/sdxs3d/datasets/butterfly"
30
  project = "unet"
31
- batch_size = 8
32
  base_learning_rate = 9e-5
33
  min_learning_rate = 1e-5
34
- num_epochs = 300
35
  # samples/save per epoch
36
  sample_interval_share = 1
37
  use_wandb = True
@@ -42,16 +42,17 @@ optimizer_type = "adam8bit"
42
  torch_compile = False
43
  unet_gradient = True
44
  clip_sample = False #Scheduler
45
- fixed_seed = False
46
  shuffle = True
 
 
 
47
  torch.backends.cuda.matmul.allow_tf32 = True
48
  torch.backends.cudnn.allow_tf32 = True
49
  torch.backends.cuda.enable_mem_efficient_sdp(False)
50
  dtype = torch.float32
51
  save_barrier = 1.03
52
  warmup_percent = 0.01
53
- dispersive_temperature=0.5
54
- dispersive_weight= 0.05
55
  percentile_clipping = 99 # 8bit optim
56
  betta2 = 0.995
57
  eps = 1e-8
@@ -86,16 +87,6 @@ if fixed_seed:
86
  if torch.cuda.is_available():
87
  torch.cuda.manual_seed_all(seed)
88
 
89
- # --- Пропорции лоссов и окно медианного нормирования (КОЭФ., не значения) ---
90
- # CHANGED: добавлен huber и dispersive в пропорции, суммы = 1.0
91
- loss_ratios = {
92
- "mse": 1.0,
93
- "mae": 0.0,
94
- "huber": 0.0,
95
- "dispersive": 0.0,
96
- }
97
- median_coeff_steps = 128 # за сколько шагов считать медианные коэффициенты
98
-
99
  # --------------------------- Параметры LoRA ---------------------------
100
  lora_name = ""
101
  lora_rank = 32
@@ -110,7 +101,7 @@ def sample_timesteps_bias(
110
  num_train_timesteps: int, # обычно 1000
111
  steps_offset: int = 0,
112
  device=None,
113
- mode: str = "uniform", # "beta", "uniform"
114
  ) -> torch.Tensor:
115
  """
116
  Возвращает timesteps с разным bias:
@@ -135,98 +126,45 @@ def sample_timesteps_bias(
135
  timesteps = steps_offset + (samples * max_idx).long().to(device)
136
  return timesteps
137
 
 
 
 
 
 
 
 
138
 
139
- # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ
140
- class MedianLossNormalizer:
141
- def __init__(self, desired_ratios: dict, window_steps: int):
142
- # нормируем доли на случай, если сумма != 1
143
- s = sum(desired_ratios.values())
144
- self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
145
- self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
146
- self.window = window_steps
147
-
148
- def update_and_total(self, losses: dict):
149
- """
150
- losses: dict ключ->тензор (значения лоссов)
151
- Поведение:
152
- - буферим ABS(l) только для активных (ratio>0) лоссов
153
- - coeff = ratio / median(abs(loss))
154
- - total = sum(coeff * loss) по активным лоссам
155
- CHANGED: буферим abs() — чтобы медиана была положительной и не ломала деление.
156
- """
157
- # буферим только активные лоссы
158
- for k, v in losses.items():
159
- if k in self.buffers and self.ratios.get(k, 0) > 0:
160
- self.buffers[k].append(float(v.detach().abs().cpu()))
161
-
162
- meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
163
- coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
164
-
165
- # суммируем только по активным (ratio>0)
166
- total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
167
- return total, coeffs, meds
168
-
169
- # создаём normalizer после определения loss_ratios
170
- normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
171
-
172
- class AccelerateDispersiveLoss:
173
- def __init__(self, accelerator, temperature=0.5, weight=0.5):
174
- self.accelerator = accelerator
175
- self.temperature = temperature
176
- self.weight = weight
177
- self.activations = []
178
- self.hooks = []
179
-
180
- def register_hooks(self, model, target_layer="down_blocks.0"):
181
- unwrapped_model = self.accelerator.unwrap_model(model)
182
- print("=== Поиск слоев в unwrapped модели ===")
183
- for name, module in unwrapped_model.named_modules():
184
- if target_layer in name:
185
- hook = module.register_forward_hook(self.hook_fn)
186
- self.hooks.append(hook)
187
- print(f"✅ Хук зарегистрирован на: {name}")
188
- break
189
-
190
- def hook_fn(self, module, input, output):
191
- if isinstance(output, tuple):
192
- activation = output[0]
193
- else:
194
- activation = output
195
- if len(activation.shape) > 2:
196
- activation = activation.view(activation.shape[0], -1)
197
- self.activations.append(activation.detach().clone())
198
-
199
- def compute_dispersive_loss(self):
200
- if not self.activations:
201
- return torch.tensor(0.0, requires_grad=True, device=device)
202
- local_activations = self.activations[-1].float()
203
- batch_size = local_activations.shape[0]
204
- if batch_size < 2:
205
- return torch.tensor(0.0, requires_grad=True, device=device)
206
- sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
207
- distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
208
- exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
209
- dispersive_loss = torch.log(torch.mean(exp_neg_dist))
210
- return dispersive_loss
211
-
212
- def clear_activations(self):
213
- self.activations.clear()
214
-
215
- def remove_hooks(self):
216
- for hook in self.hooks:
217
- hook.remove()
218
- self.hooks.clear()
219
 
220
 
221
  # --------------------------- Инициализация WandB ---------------------------
222
- if use_wandb and accelerator.is_main_process:
223
- wandb.init(project=project+lora_name, config={
224
- "batch_size": batch_size,
225
- "base_learning_rate": base_learning_rate,
226
- "num_epochs": num_epochs,
227
- "fbp": fbp,
228
- "optimizer_type": optimizer_type,
229
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  # Включение Flash Attention 2/SDPA
232
  torch.backends.cuda.enable_flash_sdp(True)
@@ -236,12 +174,7 @@ gen.manual_seed(seed)
236
 
237
  # --------------------------- Загрузка моделей ---------------------------
238
  # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
239
- #vae = AutoencoderKLWan.from_pretrained("vae", variant="fp16").to(device="cpu", dtype=torch.float16).eval()
240
- #vae = AutoencoderKLWan.from_pretrained(
241
- # "AiArtLab/simplevae", subfolder="wan16x_vae_nightly",
242
- # torch_dtype=dtype
243
- # ).to(device="cpu").eval()
244
- vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to(device).eval()
245
 
246
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
247
  if shift_factor is None:
@@ -254,14 +187,7 @@ if scaling_factor is None:
254
  latents_mean = getattr(vae.config, "latents_mean", None)
255
  latents_std = getattr(vae.config, "latents_std", None)
256
 
257
- # DDPMScheduler с V_Prediction и Zero-SNR
258
- scheduler = DDPMScheduler(
259
- num_train_timesteps=1000,
260
- prediction_type="v_prediction",
261
- rescale_betas_zero_snr=True,
262
- clip_sample = clip_sample,
263
- steps_offset = steps_offset
264
- )
265
 
266
 
267
  class DistributedResolutionBatchSampler(Sampler):
@@ -403,13 +329,6 @@ if os.path.isdir(latest_checkpoint):
403
  print(f"Ошибка при включении SDPA: {e}")
404
  unet.set_use_memory_efficient_attention_xformers(True)
405
 
406
- # Создаём hook для dispersive только если нужно
407
- if loss_ratios.get("dispersive", 0) > 0:
408
- dispersive_hook = AccelerateDispersiveLoss(
409
- accelerator=accelerator,
410
- temperature=dispersive_temperature,
411
- weight=dispersive_weight
412
- )
413
  else:
414
  # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
415
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
@@ -512,10 +431,6 @@ else:
512
  if torch.isnan(param).any() or torch.isinf(param).any():
513
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
514
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
515
-
516
- # Регистрация хуков ПОСЛЕ prepare
517
- if loss_ratios.get("dispersive", 0) > 0:
518
- dispersive_hook.register_hooks(unet, "down_blocks.2")
519
 
520
  # --------------------------- Фиксированные семплы для генерации ---------------------------
521
  fixed_samples = get_fixed_samples_by_resolution(dataset)
@@ -528,7 +443,6 @@ def generate_and_save_samples(fixed_samples_cpu, step):
528
  original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
529
  vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
530
 
531
- scheduler.set_timesteps(n_diffusion_steps)
532
 
533
  all_generated_images = []
534
  all_captions = []
@@ -546,39 +460,33 @@ def generate_and_save_samples(fixed_samples_cpu, step):
546
  )
547
  current_latents = noise.clone()
548
 
549
- if guidance_scale > 0:
550
  empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
551
  text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
552
  else:
553
  text_embeddings_batch = sample_text_embeddings
554
 
555
- for t in scheduler.timesteps:
556
- t_batch = t.repeat(current_latents.shape[0]).to(device)
557
- if guidance_scale > 0:
558
- latent_model_input = torch.cat([current_latents] * 2)
 
 
 
559
  else:
560
  latent_model_input = current_latents
 
 
 
561
 
562
- latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
563
- noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
564
-
565
- if guidance_scale > 0:
566
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
567
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
568
 
569
- current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
570
-
571
- #print(current_latents.ndim, current_latents.shape)
572
- #if current_latents.ndim == 4:
573
- # current_latents = current_latents.unsqueeze(2)
574
- # Латент в форме [B, C, T, H, W]
575
- #print(current_latents.ndim, current_latents.shape)
576
 
577
  # Параметры нормализации
578
- latent_for_vae = current_latents.detach() * scaling_factor + shift_factor
579
-
580
- if latents_mean!=None and latents_std!=None:
581
- latent_for_vae = latent_for_vae * torch.tensor(latents_std, device=device, dtype=dtype).view(1, -1, 1, 1, 1) + torch.tensor(latents_mean, device=device, dtype=dtype).view(1, -1, 1, 1, 1)
582
 
583
  decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
584
  #decoded = decoded[:, :, 0, :, :] # [3, H, W]
@@ -615,7 +523,20 @@ def generate_and_save_samples(fixed_samples_cpu, step):
615
  wandb.Image(img, caption=f"{all_captions[i]}")
616
  for i, img in enumerate(all_generated_images)
617
  ]
618
- wandb.log({"generated_images": wandb_images, "global_step": step})
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  finally:
620
  # вернуть VAE на CPU (как было в твоём коде)
621
  vae.to("cpu")
@@ -644,30 +565,6 @@ def save_checkpoint(unet,variant=""):
644
  accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
645
  unet = unet.to(dtype=dtype)
646
 
647
- def batch_pred_original_from_step(model_outputs, timesteps_tensor, noisy_latents, scheduler):
648
- device = noisy_latents.device
649
- dtype = noisy_latents.dtype
650
-
651
- available_ts = scheduler.timesteps
652
- if not isinstance(available_ts, torch.Tensor):
653
- available_ts = torch.tensor(available_ts, device="cpu")
654
- else:
655
- available_ts = available_ts.cpu()
656
-
657
- B = model_outputs.shape[0]
658
- preds = []
659
- for i in range(B):
660
- t_i = int(timesteps_tensor[i].item())
661
- diffs = torch.abs(available_ts - t_i)
662
- idx = int(torch.argmin(diffs).item())
663
- t_for_step = int(available_ts[idx].item())
664
- model_out_i = model_outputs[i:i+1]
665
- noisy_latent_i = noisy_latents[i:i+1]
666
- step_out = scheduler.step(model_out_i, t_for_step, noisy_latent_i)
667
- preds.append(step_out.pred_original_sample)
668
-
669
- return torch.cat(preds, dim=0).to(device=device, dtype=dtype)
670
-
671
  # --------------------------- Тренировочный цикл ---------------------------
672
  if accelerator.is_main_process:
673
  print(f"Total steps per GPU: {total_training_steps}")
@@ -681,12 +578,11 @@ min_loss = 1.
681
 
682
  for epoch in range(start_epoch, start_epoch + num_epochs):
683
  batch_losses = []
684
- batch_tlosses = []
685
  batch_grads = []
686
  batch_sampler.set_epoch(epoch)
687
  accelerator.wait_for_everyone()
688
  unet.train()
689
- print("epoch:",epoch)
690
  for step, (latents, embeddings) in enumerate(dataloader):
691
  with accelerator.accumulate(unet):
692
  if save_model == False and step == 5 :
@@ -695,44 +591,15 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
695
 
696
  noise = torch.randn_like(latents, dtype=latents.dtype)
697
 
698
- progress = global_step / max(1, total_training_steps - 1)
699
- timesteps = sample_timesteps_bias(
700
- batch_size=latents.shape[0],
701
- progress=progress,
702
- num_train_timesteps=scheduler.config.num_train_timesteps,
703
- steps_offset=steps_offset,
704
- device=device
705
- )
706
-
707
- noisy_latents = scheduler.add_noise(latents, noise, timesteps)
708
-
709
- if loss_ratios.get("dispersive", 0) > 0:
710
- dispersive_hook.clear_activations()
711
 
712
- #print(latents.shape,embeddings.shape)
713
- model_pred = unet(noisy_latents, timesteps, embeddings).sample
714
- target_pred = scheduler.get_velocity(latents, noise, timesteps)
715
 
716
- # === Losses ===
717
- losses_dict = {}
 
718
 
719
  mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
720
- losses_dict["mse"] = mse_loss
721
- losses_dict["mae"] = F.l1_loss(model_pred.float(), target_pred.float())
722
-
723
- # CHANGED: Huber (smooth_l1) loss added
724
- losses_dict["huber"] = F.smooth_l1_loss(model_pred.float(), target_pred.float())
725
-
726
- # === Dispersive loss ===
727
- if loss_ratios.get("dispersive", 0) > 0:
728
- disp_raw = dispersive_hook.compute_dispersive_loss().to(device) # может быть отрицательным
729
- losses_dict["dispersive"] = dispersive_hook.weight * disp_raw
730
- else:
731
- losses_dict["dispersive"] = torch.tensor(0.0, device=device)
732
-
733
- # === Нормализация всех лоссов ===
734
- abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
735
- total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
736
 
737
  # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
738
  batch_losses.append(mse_loss.detach().item())
@@ -741,7 +608,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
741
  accelerator.wait_for_everyone()
742
 
743
  # Backward
744
- accelerator.backward(total_loss)
745
 
746
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
747
  accelerator.wait_for_everyone()
@@ -765,32 +632,29 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
765
  current_lr = base_learning_rate
766
  else:
767
  current_lr = lr_scheduler.get_last_lr()[0]
768
- batch_tlosses.append(total_loss.detach().item())
769
  batch_grads.append(grad)
770
 
771
- # Логируем только активные лоссы (ratio>0)
772
- active_keys = [k for k, v in loss_ratios.items() if v > 0]
773
  log_data = {}
774
- for k in active_keys:
775
- v = losses_dict.get(k, None)
776
- if v is None:
777
- continue
778
- log_data[f"loss/{k}"] = (v.item() if isinstance(v, torch.Tensor) else float(v))
779
-
780
- log_data["loss/total"] = float(total_loss.item())
781
- log_data["loss/lr"] = current_lr
782
- for k, c in coeffs.items():
783
- log_data[f"coeff/{k}"] = float(c)
784
- if use_wandb and accelerator.sync_gradients:
785
- wandb.log(log_data, step=global_step)
786
 
787
  # Генерируем сэмплы с заданным интервалом
788
  if global_step % sample_interval == 0:
789
  generate_and_save_samples(fixed_samples,global_step)
790
  last_n = sample_interval
791
  avg_loss = float(np.mean(batch_losses[-last_n:])) if len(batch_losses) > 0 else 0.0
792
- avg_tloss = float(np.mean(batch_tlosses[-last_n:])) if len(batch_tlosses) > 0 else 0.0
793
  avg_grad = float(np.mean(batch_grads[-last_n:])) if len(batch_grads) > 0 else 0.0
 
 
 
 
 
794
  print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, grad: {avg_grad:.6f}")
795
 
796
  if save_model:
@@ -799,25 +663,26 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
799
  min_loss = avg_loss
800
  save_checkpoint(unet)
801
  if use_wandb:
802
- avg_data = {}
803
- avg_data["avg/loss"] = avg_loss
804
- avg_data["avg/tloss"] = avg_tloss
805
- avg_data["avg/grad"] = avg_grad
806
  wandb.log(avg_data, step=global_step)
 
 
 
807
 
808
  if accelerator.is_main_process:
809
  avg_epoch_loss = np.mean(batch_losses) if len(batch_losses)>0 else 0.0
810
  print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
811
  if use_wandb:
812
  wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
 
 
813
 
814
  # Завершение обучения - сохраняем финальную модель
815
- if loss_ratios.get("dispersive", 0) > 0:
816
- dispersive_hook.remove_hooks()
817
  if accelerator.is_main_process:
818
  print("Обучение завершено! Сохраняем финальную модель...")
819
  if save_model:
820
  save_checkpoint(unet,"fp16")
 
 
821
  accelerator.free_memory()
822
  if torch.distributed.is_initialized():
823
  torch.distributed.destroy_process_group()
 
8
  from torch.optim.lr_scheduler import LambdaLR
9
  from collections import defaultdict
10
  from torch.optim.lr_scheduler import LambdaLR
11
+ from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
12
  from accelerate import Accelerator
13
  from datasets import load_from_disk
14
  from tqdm import tqdm
 
28
  # --------------------------- Параметры ---------------------------
29
  ds_path = "/workspace/sdxs3d/datasets/butterfly"
30
  project = "unet"
31
+ batch_size = 16
32
  base_learning_rate = 9e-5
33
  min_learning_rate = 1e-5
34
+ num_epochs = 30
35
  # samples/save per epoch
36
  sample_interval_share = 1
37
  use_wandb = True
 
42
  torch_compile = False
43
  unet_gradient = True
44
  clip_sample = False #Scheduler
45
+ fixed_seed = True
46
  shuffle = True
47
+ use_comet_ml = False # Добавлен флаг для Comet ML
48
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Добавлен API ключ для Comet ML
49
+ comet_ml_workspace = "recoilme" # Добавлен workspace для Comet ML
50
  torch.backends.cuda.matmul.allow_tf32 = True
51
  torch.backends.cudnn.allow_tf32 = True
52
  torch.backends.cuda.enable_mem_efficient_sdp(False)
53
  dtype = torch.float32
54
  save_barrier = 1.03
55
  warmup_percent = 0.01
 
 
56
  percentile_clipping = 99 # 8bit optim
57
  betta2 = 0.995
58
  eps = 1e-8
 
87
  if torch.cuda.is_available():
88
  torch.cuda.manual_seed_all(seed)
89
 
 
 
 
 
 
 
 
 
 
 
90
  # --------------------------- Параметры LoRA ---------------------------
91
  lora_name = ""
92
  lora_rank = 32
 
101
  num_train_timesteps: int, # обычно 1000
102
  steps_offset: int = 0,
103
  device=None,
104
+ mode: str = "beta", # "beta", "uniform"
105
  ) -> torch.Tensor:
106
  """
107
  Возвращает timesteps с разным bias:
 
126
  timesteps = steps_offset + (samples * max_idx).long().to(device)
127
  return timesteps
128
 
129
+ def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
130
+ normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
131
+
132
+ logit_normal_samples = torch.sigmoid(normal_samples)
133
+
134
+ return logit_normal_samples
135
+
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  # --------------------------- Инициализация WandB ---------------------------
140
+ if accelerator.is_main_process:
141
+ if use_wandb:
142
+ wandb.init(project=project+lora_name, config={
143
+ "batch_size": batch_size,
144
+ "base_learning_rate": base_learning_rate,
145
+ "num_epochs": num_epochs,
146
+ "fbp": fbp,
147
+ "optimizer_type": optimizer_type,
148
+ })
149
+ if use_comet_ml:
150
+ from comet_ml import Experiment
151
+ comet_experiment = Experiment(
152
+ api_key=comet_ml_api_key,
153
+ project_name=project,
154
+ workspace=comet_ml_workspace
155
+ )
156
+ # Логируем гиперпараметры в Comet ML
157
+ hyper_params = {
158
+ "batch_size": batch_size,
159
+ "base_learning_rate": base_learning_rate,
160
+ "min_learning_rate": min_learning_rate,
161
+ "num_epochs": num_epochs,
162
+ "n_diffusion_steps": n_diffusion_steps,
163
+ "guidance_scale": guidance_scale,
164
+ "optimizer_type": optimizer_type,
165
+ "mixed_precision": mixed_precision,
166
+ }
167
+ comet_experiment.log_parameters(hyper_params)
168
 
169
  # Включение Flash Attention 2/SDPA
170
  torch.backends.cuda.enable_flash_sdp(True)
 
174
 
175
  # --------------------------- Загрузка моделей ---------------------------
176
  # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
177
+ vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to("cpu").eval()
 
 
 
 
 
178
 
179
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
180
  if shift_factor is None:
 
187
  latents_mean = getattr(vae.config, "latents_mean", None)
188
  latents_std = getattr(vae.config, "latents_std", None)
189
 
190
+
 
 
 
 
 
 
 
191
 
192
 
193
  class DistributedResolutionBatchSampler(Sampler):
 
329
  print(f"Ошибка при включении SDPA: {e}")
330
  unet.set_use_memory_efficient_attention_xformers(True)
331
 
 
 
 
 
 
 
 
332
  else:
333
  # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
334
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
 
431
  if torch.isnan(param).any() or torch.isinf(param).any():
432
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
433
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
 
 
 
 
434
 
435
  # --------------------------- Фиксированные семплы для генерации ---------------------------
436
  fixed_samples = get_fixed_samples_by_resolution(dataset)
 
443
  original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
444
  vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
445
 
 
446
 
447
  all_generated_images = []
448
  all_captions = []
 
460
  )
461
  current_latents = noise.clone()
462
 
463
+ if guidance_scale != 1:
464
  empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
465
  text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
466
  else:
467
  text_embeddings_batch = sample_text_embeddings
468
 
469
+ timesteps = torch.linspace(0, 1, n_diffusion_steps+1, device=device, dtype=sample_latents.dtype)
470
+ for i in range(0, n_diffusion_steps):
471
+ t_cur = timesteps[i].unsqueeze(0)
472
+ t_next = timesteps[i+1]
473
+ dt = t_next - t_cur
474
+ if guidance_scale != 1:
475
+ latent_model_input = torch.cat((current_latents, current_latents))
476
  else:
477
  latent_model_input = current_latents
478
+ t_batch = t_cur.repeat(latent_model_input.shape[0]).to(device)
479
+ t_batch = (t_batch * 1000).long().view(-1)
480
+ flow = original_model(latent_model_input, t_batch, text_embeddings_batch).sample
481
 
482
+ if guidance_scale != 1:
483
+ flow_uncond, flow_cond = flow.chunk(2)
484
+ flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
 
 
 
485
 
486
+ current_latents = current_latents + flow * dt.to(device)
 
 
 
 
 
 
487
 
488
  # Параметры нормализации
489
+ latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
 
 
 
490
 
491
  decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
492
  #decoded = decoded[:, :, 0, :, :] # [3, H, W]
 
523
  wandb.Image(img, caption=f"{all_captions[i]}")
524
  for i, img in enumerate(all_generated_images)
525
  ]
526
+ wandb.log({"generated_images": wandb_images})
527
+ if use_comet_ml and accelerator.is_main_process:
528
+ for i, img in enumerate(all_generated_images):
529
+ comet_experiment.log_image(
530
+ image_data=img,
531
+ name=f"step_{step}_img_{i}",
532
+ step=step,
533
+ metadata={
534
+ "caption": all_captions[i],
535
+ "width": img.width,
536
+ "height": img.height,
537
+ "global_step": step
538
+ }
539
+ )
540
  finally:
541
  # вернуть VAE на CPU (как было в твоём коде)
542
  vae.to("cpu")
 
565
  accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
566
  unet = unet.to(dtype=dtype)
567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  # --------------------------- Тренировочный цикл ---------------------------
569
  if accelerator.is_main_process:
570
  print(f"Total steps per GPU: {total_training_steps}")
 
578
 
579
  for epoch in range(start_epoch, start_epoch + num_epochs):
580
  batch_losses = []
 
581
  batch_grads = []
582
  batch_sampler.set_epoch(epoch)
583
  accelerator.wait_for_everyone()
584
  unet.train()
585
+ #print("epoch:",epoch)
586
  for step, (latents, embeddings) in enumerate(dataloader):
587
  with accelerator.accumulate(unet):
588
  if save_model == False and step == 5 :
 
591
 
592
  noise = torch.randn_like(latents, dtype=latents.dtype)
593
 
594
+ t = logit_normal_samples((batch_size, 1, 1, 1), mu=0.0, sigma=1.0, device=latents.device, dtype=latents.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
595
 
596
+ noisy_latents = (1 - t) * noise + t * latents
 
 
597
 
598
+ t_for_unet = (t * 1000).long().view(-1)
599
+ model_pred = unet(noisy_latents, t_for_unet, embeddings).sample
600
+ target_pred = latents - noise
601
 
602
  mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
603
 
604
  # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
605
  batch_losses.append(mse_loss.detach().item())
 
608
  accelerator.wait_for_everyone()
609
 
610
  # Backward
611
+ accelerator.backward(mse_loss)
612
 
613
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
614
  accelerator.wait_for_everyone()
 
632
  current_lr = base_learning_rate
633
  else:
634
  current_lr = lr_scheduler.get_last_lr()[0]
 
635
  batch_grads.append(grad)
636
 
 
 
637
  log_data = {}
638
+ log_data["loss"] = mse_loss.detach().item()
639
+ log_data["lr"] = current_lr
640
+ log_data["grad"] = grad
641
+ if accelerator.sync_gradients:
642
+ if use_wandb:
643
+ wandb.log(log_data, step=global_step)
644
+ if use_comet_ml:
645
+ comet_experiment.log_metrics(log_data, step=global_step)
 
 
 
 
646
 
647
  # Генерируем сэмплы с заданным интервалом
648
  if global_step % sample_interval == 0:
649
  generate_and_save_samples(fixed_samples,global_step)
650
  last_n = sample_interval
651
  avg_loss = float(np.mean(batch_losses[-last_n:])) if len(batch_losses) > 0 else 0.0
 
652
  avg_grad = float(np.mean(batch_grads[-last_n:])) if len(batch_grads) > 0 else 0.0
653
+
654
+ avg_data = {}
655
+ avg_data["avg_loss"] = avg_loss
656
+ avg_data["avg_grad"] = avg_grad
657
+
658
  print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, grad: {avg_grad:.6f}")
659
 
660
  if save_model:
 
663
  min_loss = avg_loss
664
  save_checkpoint(unet)
665
  if use_wandb:
 
 
 
 
666
  wandb.log(avg_data, step=global_step)
667
+ if use_comet_ml:
668
+ comet_experiment.log_metrics(avg_data, step=global_step)
669
+
670
 
671
  if accelerator.is_main_process:
672
  avg_epoch_loss = np.mean(batch_losses) if len(batch_losses)>0 else 0.0
673
  print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
674
  if use_wandb:
675
  wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
676
+ #if use_comet_ml:
677
+ # comet_experiment.log_metrics(epoch_data)
678
 
679
  # Завершение обучения - сохраняем финальную модель
 
 
680
  if accelerator.is_main_process:
681
  print("Обучение завершено! Сохраняем финальную модель...")
682
  if save_model:
683
  save_checkpoint(unet,"fp16")
684
+ if use_comet_ml:
685
+ comet_experiment.end()
686
  accelerator.free_memory()
687
  if torch.distributed.is_initialized():
688
  torch.distributed.destroy_process_group()
train_velocity.py ADDED
@@ -0,0 +1,825 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader, Sampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+ from collections import defaultdict
10
+ from torch.optim.lr_scheduler import LambdaLR
11
+ from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL, DDPMScheduler
12
+ from accelerate import Accelerator
13
+ from datasets import load_from_disk
14
+ from tqdm import tqdm
15
+ from PIL import Image,ImageOps
16
+ import wandb
17
+ import random
18
+ import gc
19
+ from accelerate.state import DistributedType
20
+ from torch.distributed import broadcast_object_list
21
+ from torch.utils.checkpoint import checkpoint
22
+ from diffusers.models.attention_processor import AttnProcessor2_0
23
+ from datetime import datetime
24
+ import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
+ from collections import deque
27
+
28
+ # --------------------------- Параметры ---------------------------
29
+ ds_path = "/workspace/sdxs3d/datasets/butterfly"
30
+ project = "unet"
31
+ batch_size = 16
32
+ base_learning_rate = 9e-5
33
+ min_learning_rate = 1e-5
34
+ num_epochs = 300
35
+ # samples/save per epoch
36
+ sample_interval_share = 1
37
+ use_wandb = True
38
+ save_model = True
39
+ use_decay = True
40
+ fbp = False # fused backward pass
41
+ optimizer_type = "adam8bit"
42
+ torch_compile = False
43
+ unet_gradient = True
44
+ clip_sample = False #Scheduler
45
+ fixed_seed = True
46
+ shuffle = True
47
+ torch.backends.cuda.matmul.allow_tf32 = True
48
+ torch.backends.cudnn.allow_tf32 = True
49
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
50
+ dtype = torch.float32
51
+ save_barrier = 2.03 # TODO
52
+ warmup_percent = 0.01
53
+ dispersive_temperature=0.5
54
+ dispersive_weight= 0.05
55
+ percentile_clipping = 99 # 8bit optim
56
+ betta2 = 0.995
57
+ eps = 1e-8
58
+ clip_grad_norm = 1.0
59
+ steps_offset = 0 # Scheduler
60
+ limit = 0
61
+ checkpoints_folder = ""
62
+ mixed_precision = "no" #"fp16"
63
+ gradient_accumulation_steps = 1
64
+ accelerator = Accelerator(
65
+ mixed_precision=mixed_precision,
66
+ gradient_accumulation_steps=gradient_accumulation_steps
67
+ )
68
+ device = accelerator.device
69
+
70
+ # Параметры для диффузии
71
+ n_diffusion_steps = 50
72
+ samples_to_generate = 12
73
+ guidance_scale = 5
74
+
75
+ # Папки для сохранения результатов
76
+ generated_folder = "samples"
77
+ os.makedirs(generated_folder, exist_ok=True)
78
+
79
+ # Настройка seed для воспроизводимости
80
+ current_date = datetime.now()
81
+ seed = int(current_date.strftime("%Y%m%d"))
82
+ if fixed_seed:
83
+ torch.manual_seed(seed)
84
+ np.random.seed(seed)
85
+ random.seed(seed)
86
+ if torch.cuda.is_available():
87
+ torch.cuda.manual_seed_all(seed)
88
+
89
+ # --- Пропорции лоссов и окно медианного нормирования (КОЭФ., не значения) ---
90
+ # CHANGED: добавлен huber и dispersive в пропорции, суммы = 1.0
91
+ loss_ratios = {
92
+ "mse": 1.0,
93
+ "mae": 0.0,
94
+ "huber": 0.0,
95
+ "dispersive": 0.0,
96
+ }
97
+ median_coeff_steps = 128 # за сколько шагов считать медианные коэффициенты
98
+
99
+ # --------------------------- Параметры LoRA ---------------------------
100
+ lora_name = ""
101
+ lora_rank = 32
102
+ lora_alpha = 64
103
+
104
+ print("init")
105
+
106
+ # --------------------------- вспомогательные функции ---------------------------
107
+ def sample_timesteps_bias(
108
+ batch_size: int,
109
+ progress: float, # [0..1]
110
+ num_train_timesteps: int, # обычно 1000
111
+ steps_offset: int = 0,
112
+ device=None,
113
+ mode: str = "uniform", # "beta", "uniform"
114
+ ) -> torch.Tensor:
115
+ """
116
+ Возвращает timesteps с разным bias:
117
+ - beta : как раньше (сдвиг в начало или конец в зависимости от progress)
118
+ - normal : около середины (гауссовое распределение)
119
+ - uniform: равномерно по всем timestep’ам
120
+ """
121
+
122
+ max_idx = num_train_timesteps - 1 - steps_offset
123
+
124
+ if mode == "beta":
125
+ alpha = 1.0 + .5 * (1.0 - progress)
126
+ beta = 1.0 + .5 * progress
127
+ samples = torch.distributions.Beta(alpha, beta).sample((batch_size,))
128
+
129
+ elif mode == "uniform":
130
+ samples = torch.rand(batch_size)
131
+
132
+ else:
133
+ raise ValueError(f"Unknown mode: {mode}")
134
+
135
+ timesteps = steps_offset + (samples * max_idx).long().to(device)
136
+ return timesteps
137
+
138
+
139
+ # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ
140
+ class MedianLossNormalizer:
141
+ def __init__(self, desired_ratios: dict, window_steps: int):
142
+ # нормируем доли на случай, если сумма != 1
143
+ s = sum(desired_ratios.values())
144
+ self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
145
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
146
+ self.window = window_steps
147
+
148
+ def update_and_total(self, losses: dict):
149
+ """
150
+ losses: dict ключ->тензор (значения лоссов)
151
+ Поведение:
152
+ - буферим ABS(l) только для активных (ratio>0) лоссов
153
+ - coeff = ratio / median(abs(loss))
154
+ - total = sum(coeff * loss) по активным лоссам
155
+ CHANGED: буферим abs() — чтобы медиана была положительной и не ломала деление.
156
+ """
157
+ # буферим только активные лоссы
158
+ for k, v in losses.items():
159
+ if k in self.buffers and self.ratios.get(k, 0) > 0:
160
+ self.buffers[k].append(float(v.detach().abs().cpu()))
161
+
162
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
163
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
164
+
165
+ # суммируем только по активным (ratio>0)
166
+ total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
167
+ return total, coeffs, meds
168
+
169
+ # создаём normalizer после определения loss_ratios
170
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
171
+
172
+ class AccelerateDispersiveLoss:
173
+ def __init__(self, accelerator, temperature=0.5, weight=0.5):
174
+ self.accelerator = accelerator
175
+ self.temperature = temperature
176
+ self.weight = weight
177
+ self.activations = []
178
+ self.hooks = []
179
+
180
+ def register_hooks(self, model, target_layer="down_blocks.0"):
181
+ unwrapped_model = self.accelerator.unwrap_model(model)
182
+ print("=== Поиск слоев в unwrapped модели ===")
183
+ for name, module in unwrapped_model.named_modules():
184
+ if target_layer in name:
185
+ hook = module.register_forward_hook(self.hook_fn)
186
+ self.hooks.append(hook)
187
+ print(f"✅ Хук зарегистрирован на: {name}")
188
+ break
189
+
190
+ def hook_fn(self, module, input, output):
191
+ if isinstance(output, tuple):
192
+ activation = output[0]
193
+ else:
194
+ activation = output
195
+ if len(activation.shape) > 2:
196
+ activation = activation.view(activation.shape[0], -1)
197
+ self.activations.append(activation.detach().clone())
198
+
199
+ def compute_dispersive_loss(self):
200
+ if not self.activations:
201
+ return torch.tensor(0.0, requires_grad=True, device=device)
202
+ local_activations = self.activations[-1].float()
203
+ batch_size = local_activations.shape[0]
204
+ if batch_size < 2:
205
+ return torch.tensor(0.0, requires_grad=True, device=device)
206
+ sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
207
+ distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
208
+ exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
209
+ dispersive_loss = torch.log(torch.mean(exp_neg_dist))
210
+ return dispersive_loss
211
+
212
+ def clear_activations(self):
213
+ self.activations.clear()
214
+
215
+ def remove_hooks(self):
216
+ for hook in self.hooks:
217
+ hook.remove()
218
+ self.hooks.clear()
219
+
220
+
221
+ # --------------------------- Инициализация WandB ---------------------------
222
+ if use_wandb and accelerator.is_main_process:
223
+ wandb.init(project=project+lora_name, config={
224
+ "batch_size": batch_size,
225
+ "base_learning_rate": base_learning_rate,
226
+ "num_epochs": num_epochs,
227
+ "fbp": fbp,
228
+ "optimizer_type": optimizer_type,
229
+ })
230
+
231
+ # Включение Flash Attention 2/SDPA
232
+ torch.backends.cuda.enable_flash_sdp(True)
233
+ # --------------------------- Инициализация Accelerator --------------------
234
+ gen = torch.Generator(device=device)
235
+ gen.manual_seed(seed)
236
+
237
+ # --------------------------- Загрузка моделей ---------------------------
238
+ # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
239
+ #vae = AutoencoderKLWan.from_pretrained("vae", variant="fp16").to(device="cpu", dtype=torch.float16).eval()
240
+ #vae = AutoencoderKLWan.from_pretrained(
241
+ # "AiArtLab/simplevae", subfolder="wan16x_vae_nightly",
242
+ # torch_dtype=dtype
243
+ # ).to(device="cpu").eval()
244
+ vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to(device).eval()
245
+
246
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
247
+ if shift_factor is None:
248
+ shift_factor = 0.0
249
+
250
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
251
+ if scaling_factor is None:
252
+ scaling_factor = 1.0
253
+
254
+ latents_mean = getattr(vae.config, "latents_mean", None)
255
+ latents_std = getattr(vae.config, "latents_std", None)
256
+
257
+ # DDPMScheduler с V_Prediction и Zero-SNR
258
+ scheduler = DDPMScheduler(
259
+ num_train_timesteps=1000,
260
+ prediction_type="v_prediction",
261
+ rescale_betas_zero_snr=True,
262
+ clip_sample = clip_sample,
263
+ steps_offset = steps_offset
264
+ )
265
+
266
+
267
+ class DistributedResolutionBatchSampler(Sampler):
268
+ def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
269
+ self.dataset = dataset
270
+ self.batch_size = max(1, batch_size // num_replicas)
271
+ self.num_replicas = num_replicas
272
+ self.rank = rank
273
+ self.shuffle = shuffle
274
+ self.drop_last = drop_last
275
+ self.epoch = 0
276
+
277
+ try:
278
+ widths = np.array(dataset["width"])
279
+ heights = np.array(dataset["height"])
280
+ except KeyError:
281
+ widths = np.zeros(len(dataset))
282
+ heights = np.zeros(len(dataset))
283
+
284
+ self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
285
+ self.size_groups = {}
286
+ for w, h in self.size_keys:
287
+ mask = (widths == w) & (heights == h)
288
+ self.size_groups[(w, h)] = np.where(mask)[0]
289
+
290
+ self.group_num_batches = {}
291
+ total_batches = 0
292
+ for size, indices in self.size_groups.items():
293
+ num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
294
+ self.group_num_batches[size] = num_full_batches
295
+ total_batches += num_full_batches
296
+
297
+ self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
298
+
299
+ def __iter__(self):
300
+ if torch.cuda.is_available():
301
+ torch.cuda.empty_cache()
302
+ all_batches = []
303
+ rng = np.random.RandomState(self.epoch)
304
+
305
+ for size, indices in self.size_groups.items():
306
+ indices = indices.copy()
307
+ if self.shuffle:
308
+ rng.shuffle(indices)
309
+ num_full_batches = self.group_num_batches[size]
310
+ if num_full_batches == 0:
311
+ continue
312
+ valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
313
+ batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
314
+ start_idx = self.rank * self.batch_size
315
+ end_idx = start_idx + self.batch_size
316
+ gpu_batches = batches[:, start_idx:end_idx]
317
+ all_batches.extend(gpu_batches)
318
+
319
+ if self.shuffle:
320
+ rng.shuffle(all_batches)
321
+ accelerator.wait_for_everyone()
322
+ return iter(all_batches)
323
+
324
+ def __len__(self):
325
+ return self.num_batches
326
+
327
+ def set_epoch(self, epoch):
328
+ self.epoch = epoch
329
+
330
+ # Функция для выборки фиксированных семплов по размерам
331
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
332
+ size_groups = defaultdict(list)
333
+ try:
334
+ widths = dataset["width"]
335
+ heights = dataset["height"]
336
+ except KeyError:
337
+ widths = [0] * len(dataset)
338
+ heights = [0] * len(dataset)
339
+ for i, (w, h) in enumerate(zip(widths, heights)):
340
+ size = (w, h)
341
+ size_groups[size].append(i)
342
+
343
+ fixed_samples = {}
344
+ for size, indices in size_groups.items():
345
+ n_samples = min(samples_per_group, len(indices))
346
+ if len(size_groups)==1:
347
+ n_samples = samples_to_generate
348
+ if n_samples == 0:
349
+ continue
350
+ sample_indices = random.sample(indices, n_samples)
351
+ samples_data = [dataset[idx] for idx in sample_indices]
352
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
353
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
354
+ texts = [item["text"] for item in samples_data]
355
+ fixed_samples[size] = (latents, embeddings, texts)
356
+
357
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
358
+ return fixed_samples
359
+
360
+ if limit > 0:
361
+ dataset = load_from_disk(ds_path).select(range(limit))
362
+ else:
363
+ dataset = load_from_disk(ds_path)
364
+
365
+ def collate_fn_simple(batch):
366
+ latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
367
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
368
+ return latents, embeddings
369
+
370
+ batch_sampler = DistributedResolutionBatchSampler(
371
+ dataset=dataset,
372
+ batch_size=batch_size,
373
+ num_replicas=accelerator.num_processes,
374
+ rank=accelerator.process_index,
375
+ shuffle=shuffle
376
+ )
377
+
378
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
379
+ print("Total samples",len(dataloader))
380
+ dataloader = accelerator.prepare(dataloader)
381
+
382
+ start_epoch = 0
383
+ global_step = 0
384
+ total_training_steps = (len(dataloader) * num_epochs)
385
+ world_size = accelerator.state.num_processes
386
+
387
+ # Опция загрузки модели из последнего чекпоинта (если существует)
388
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
389
+ if os.path.isdir(latest_checkpoint):
390
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
391
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
392
+ if torch_compile:
393
+ print("compiling")
394
+ torch.set_float32_matmul_precision('high')
395
+ unet = torch.compile(unet)
396
+ print("compiling - ok")
397
+ if unet_gradient:
398
+ unet.enable_gradient_checkpointing()
399
+ unet.set_use_memory_efficient_attention_xformers(False)
400
+ try:
401
+ unet.set_attn_processor(AttnProcessor2_0())
402
+ except Exception as e:
403
+ print(f"Ошибка при включении SDPA: {e}")
404
+ unet.set_use_memory_efficient_attention_xformers(True)
405
+
406
+ # Создаём hook для dispersive только если нужно
407
+ if loss_ratios.get("dispersive", 0) > 0:
408
+ dispersive_hook = AccelerateDispersiveLoss(
409
+ accelerator=accelerator,
410
+ temperature=dispersive_temperature,
411
+ weight=dispersive_weight
412
+ )
413
+ else:
414
+ # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
415
+ raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
416
+
417
+ if lora_name:
418
+ print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
419
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
420
+ from peft.tuners.lora import LoraModel
421
+ import os
422
+ unet.requires_grad_(False)
423
+ print("Параметры базового UNet заморожены.")
424
+
425
+ lora_config = LoraConfig(
426
+ r=lora_rank,
427
+ lora_alpha=lora_alpha,
428
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
429
+ )
430
+ unet.add_adapter(lora_config)
431
+
432
+ from peft import get_peft_model
433
+ peft_unet = get_peft_model(unet, lora_config)
434
+ params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
435
+
436
+ if accelerator.is_main_process:
437
+ lora_params_count = sum(p.numel() for p in params_to_optimize)
438
+ total_params_count = sum(p.numel() for p in unet.parameters())
439
+ print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
440
+ print(f"Общее количество параметров UNet: {total_params_count:,}")
441
+
442
+ lora_save_path = os.path.join("lora", lora_name)
443
+ os.makedirs(lora_save_path, exist_ok=True)
444
+
445
+ def save_lora_checkpoint(model):
446
+ if accelerator.is_main_process:
447
+ print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
448
+ from peft.utils.save_and_load import get_peft_model_state_dict
449
+ lora_state_dict = get_peft_model_state_dict(model)
450
+ torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
451
+ model.peft_config["default"].save_pretrained(lora_save_path)
452
+ from diffusers import StableDiffusionXLPipeline
453
+ StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
454
+
455
+ # --------------------------- Оптимизатор ---------------------------
456
+ if lora_name:
457
+ trainable_params = [p for p in unet.parameters() if p.requires_grad]
458
+ else:
459
+ if fbp:
460
+ trainable_params = list(unet.parameters())
461
+
462
+ def create_optimizer(name, params):
463
+ if name == "adam8bit":
464
+ return bnb.optim.AdamW8bit(
465
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
466
+ percentile_clipping=percentile_clipping
467
+ )
468
+ elif name == "adam":
469
+ return torch.optim.AdamW(
470
+ params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
471
+ )
472
+ elif name == "lion8bit":
473
+ return bnb.optim.Lion8bit(
474
+ params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
475
+ percentile_clipping=percentile_clipping
476
+ )
477
+ elif name == "adafactor":
478
+ from transformers import Adafactor
479
+ return Adafactor(
480
+ params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
481
+ warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
482
+ beta1=0.9, weight_decay=0.01
483
+ )
484
+ else:
485
+ raise ValueError(f"Unknown optimizer: {name}")
486
+
487
+ if fbp:
488
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
489
+ def optimizer_hook(param):
490
+ optimizer_dict[param].step()
491
+ optimizer_dict[param].zero_grad(set_to_none=True)
492
+ for param in trainable_params:
493
+ param.register_post_accumulate_grad_hook(optimizer_hook)
494
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
495
+ else:
496
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
497
+ def lr_schedule(step):
498
+ x = step / (total_training_steps * world_size)
499
+ warmup = warmup_percent
500
+ if not use_decay:
501
+ return base_learning_rate
502
+ if x < warmup:
503
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
504
+ decay_ratio = (x - warmup) / (1 - warmup)
505
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
506
+ (1 + math.cos(math.pi * decay_ratio))
507
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
508
+
509
+ num_params = sum(p.numel() for p in unet.parameters())
510
+ print(f"[rank {accelerator.process_index}] total params: {num_params}")
511
+ for name, param in unet.named_parameters():
512
+ if torch.isnan(param).any() or torch.isinf(param).any():
513
+ print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
514
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
515
+
516
+ # Регистрация хуков ПОСЛЕ prepare
517
+ if loss_ratios.get("dispersive", 0) > 0:
518
+ dispersive_hook.register_hooks(unet, "down_blocks.2")
519
+
520
+ # --------------------------- Фиксированные семплы для генерации ---------------------------
521
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
522
+
523
+ @torch.compiler.disable()
524
+ @torch.no_grad()
525
+ def generate_and_save_samples(fixed_samples_cpu, step):
526
+ original_model = None
527
+ try:
528
+ original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
529
+ vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
530
+
531
+ scheduler.set_timesteps(n_diffusion_steps)
532
+
533
+ all_generated_images = []
534
+ all_captions = []
535
+
536
+ for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
537
+ width, height = size
538
+ sample_latents = sample_latents.to(dtype=dtype, device=device)
539
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
540
+
541
+ noise = torch.randn(
542
+ sample_latents.shape,
543
+ generator=gen,
544
+ device=device,
545
+ dtype=sample_latents.dtype
546
+ )
547
+ current_latents = noise.clone()
548
+
549
+ if guidance_scale > 0:
550
+ empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
551
+ text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
552
+ else:
553
+ text_embeddings_batch = sample_text_embeddings
554
+
555
+ for t in scheduler.timesteps:
556
+ t_batch = t.repeat(current_latents.shape[0]).to(device)
557
+ if guidance_scale > 0:
558
+ latent_model_input = torch.cat([current_latents] * 2)
559
+ else:
560
+ latent_model_input = current_latents
561
+
562
+ latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
563
+ noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
564
+
565
+ if guidance_scale > 0:
566
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
567
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
568
+
569
+ current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
570
+
571
+ #print(current_latents.ndim, current_latents.shape)
572
+ #if current_latents.ndim == 4:
573
+ # current_latents = current_latents.unsqueeze(2)
574
+ # Латент в форме [B, C, T, H, W]
575
+ #print(current_latents.ndim, current_latents.shape)
576
+
577
+ # Параметры нормализации
578
+ latent_for_vae = current_latents.detach() * scaling_factor + shift_factor
579
+
580
+ if latents_mean!=None and latents_std!=None:
581
+ latent_for_vae = latent_for_vae * torch.tensor(latents_std, device=device, dtype=dtype).view(1, -1, 1, 1, 1) + torch.tensor(latents_mean, device=device, dtype=dtype).view(1, -1, 1, 1, 1)
582
+
583
+ decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
584
+ #decoded = decoded[:, :, 0, :, :] # [3, H, W]
585
+ #print(decoded.ndim, decoded.shape)
586
+
587
+ decoded_fp32 = decoded.to(torch.float32)
588
+ for img_idx, img_tensor in enumerate(decoded_fp32):
589
+
590
+ # Форма: [3, H, W] -> преобразуем в [H, W, 3]
591
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
592
+ img = img.transpose(1, 2, 0) # Из [3, H, W] в [H, W, 3]
593
+
594
+ #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
595
+ if np.isnan(img).any():
596
+ print("NaNs found, saving stopped! Step:", step)
597
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
598
+
599
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
600
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
601
+ max_w_overall = max(255, max_w_overall)
602
+ max_h_overall = max(255, max_h_overall)
603
+
604
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
605
+ all_generated_images.append(padded_img)
606
+
607
+ caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
608
+ all_captions.append(caption_text)
609
+
610
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
611
+ pil_img.save(sample_path, "JPEG", quality=96)
612
+
613
+ if use_wandb and accelerator.is_main_process:
614
+ wandb_images = [
615
+ wandb.Image(img, caption=f"{all_captions[i]}")
616
+ for i, img in enumerate(all_generated_images)
617
+ ]
618
+ wandb.log({"generated_images": wandb_images, "global_step": step})
619
+ finally:
620
+ # вернуть VAE на CPU (как было в твоём коде)
621
+ vae.to("cpu")
622
+ for var in list(locals().keys()):
623
+ if isinstance(locals()[var], torch.Tensor):
624
+ del locals()[var]
625
+ torch.cuda.empty_cache()
626
+ gc.collect()
627
+
628
+ # --------------------------- Генерация сэмплов перед обучением ---------------------------
629
+ if accelerator.is_main_process:
630
+ if save_model:
631
+ print("Генерация сэмплов до старта обучения...")
632
+ generate_and_save_samples(fixed_samples,0)
633
+ accelerator.wait_for_everyone()
634
+
635
+ # Модифицируем функцию сохранения модели для поддержки LoRA
636
+ def save_checkpoint(unet,variant=""):
637
+ if accelerator.is_main_process:
638
+ if lora_name:
639
+ save_lora_checkpoint(unet)
640
+ else:
641
+ if variant!="":
642
+ accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
643
+ else:
644
+ accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
645
+ unet = unet.to(dtype=dtype)
646
+
647
+ def batch_pred_original_from_step(model_outputs, timesteps_tensor, noisy_latents, scheduler):
648
+ device = noisy_latents.device
649
+ dtype = noisy_latents.dtype
650
+
651
+ available_ts = scheduler.timesteps
652
+ if not isinstance(available_ts, torch.Tensor):
653
+ available_ts = torch.tensor(available_ts, device="cpu")
654
+ else:
655
+ available_ts = available_ts.cpu()
656
+
657
+ B = model_outputs.shape[0]
658
+ preds = []
659
+ for i in range(B):
660
+ t_i = int(timesteps_tensor[i].item())
661
+ diffs = torch.abs(available_ts - t_i)
662
+ idx = int(torch.argmin(diffs).item())
663
+ t_for_step = int(available_ts[idx].item())
664
+ model_out_i = model_outputs[i:i+1]
665
+ noisy_latent_i = noisy_latents[i:i+1]
666
+ step_out = scheduler.step(model_out_i, t_for_step, noisy_latent_i)
667
+ preds.append(step_out.pred_original_sample)
668
+
669
+ return torch.cat(preds, dim=0).to(device=device, dtype=dtype)
670
+
671
+ # --------------------------- Тренировочный цикл ---------------------------
672
+ if accelerator.is_main_process:
673
+ print(f"Total steps per GPU: {total_training_steps}")
674
+
675
+ epoch_loss_points = []
676
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
677
+
678
+ steps_per_epoch = len(dataloader)
679
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
680
+ min_loss = 1.
681
+
682
+ for epoch in range(start_epoch, start_epoch + num_epochs):
683
+ batch_losses = []
684
+ batch_tlosses = []
685
+ batch_grads = []
686
+ batch_sampler.set_epoch(epoch)
687
+ accelerator.wait_for_everyone()
688
+ unet.train()
689
+ print("epoch:",epoch)
690
+ for step, (latents, embeddings) in enumerate(dataloader):
691
+ with accelerator.accumulate(unet):
692
+ if save_model == False and step == 5 :
693
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
694
+ print(f"Шаг {step}: {used_gb:.2f} GB")
695
+
696
+ noise = torch.randn_like(latents, dtype=latents.dtype)
697
+
698
+ progress = global_step / max(1, total_training_steps - 1)
699
+ timesteps = sample_timesteps_bias(
700
+ batch_size=latents.shape[0],
701
+ progress=progress,
702
+ num_train_timesteps=scheduler.config.num_train_timesteps,
703
+ steps_offset=steps_offset,
704
+ device=device
705
+ )
706
+
707
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
708
+
709
+ if loss_ratios.get("dispersive", 0) > 0:
710
+ dispersive_hook.clear_activations()
711
+
712
+ #print(latents.shape,embeddings.shape)
713
+ model_pred = unet(noisy_latents, timesteps, embeddings).sample
714
+ target_pred = scheduler.get_velocity(latents, noise, timesteps)
715
+
716
+ # === Losses ===
717
+ losses_dict = {}
718
+
719
+ mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
720
+ losses_dict["mse"] = mse_loss
721
+ losses_dict["mae"] = F.l1_loss(model_pred.float(), target_pred.float())
722
+
723
+ # CHANGED: Huber (smooth_l1) loss added
724
+ losses_dict["huber"] = F.smooth_l1_loss(model_pred.float(), target_pred.float())
725
+
726
+ # === Dispersive loss ===
727
+ if loss_ratios.get("dispersive", 0) > 0:
728
+ disp_raw = dispersive_hook.compute_dispersive_loss().to(device) # может быть отрицательным
729
+ losses_dict["dispersive"] = dispersive_hook.weight * disp_raw
730
+ else:
731
+ losses_dict["dispersive"] = torch.tensor(0.0, device=device)
732
+
733
+ # === Нормализация всех лоссов ===
734
+ abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
735
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
736
+
737
+ # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
738
+ batch_losses.append(mse_loss.detach().item())
739
+
740
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
741
+ accelerator.wait_for_everyone()
742
+
743
+ # Backward
744
+ accelerator.backward(total_loss)
745
+
746
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
747
+ accelerator.wait_for_everyone()
748
+
749
+ grad = 0.0
750
+ if not fbp:
751
+ if accelerator.sync_gradients:
752
+ with torch.amp.autocast('cuda', enabled=False):
753
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
754
+ grad = float(grad_val)
755
+ optimizer.step()
756
+ lr_scheduler.step()
757
+ optimizer.zero_grad(set_to_none=True)
758
+
759
+ global_step += 1
760
+ progress_bar.update(1)
761
+
762
+ # Логируем метрики
763
+ if accelerator.is_main_process:
764
+ if fbp:
765
+ current_lr = base_learning_rate
766
+ else:
767
+ current_lr = lr_scheduler.get_last_lr()[0]
768
+ batch_tlosses.append(total_loss.detach().item())
769
+ batch_grads.append(grad)
770
+
771
+ # Логируем только активные лоссы (ratio>0)
772
+ active_keys = [k for k, v in loss_ratios.items() if v > 0]
773
+ log_data = {}
774
+ for k in active_keys:
775
+ v = losses_dict.get(k, None)
776
+ if v is None:
777
+ continue
778
+ log_data[f"loss/{k}"] = (v.item() if isinstance(v, torch.Tensor) else float(v))
779
+
780
+ log_data["loss/total"] = float(total_loss.item())
781
+ log_data["loss/lr"] = current_lr
782
+ for k, c in coeffs.items():
783
+ log_data[f"coeff/{k}"] = float(c)
784
+ if use_wandb and accelerator.sync_gradients:
785
+ wandb.log(log_data, step=global_step)
786
+
787
+ # Генерируем сэмплы с заданным интервалом
788
+ if global_step % sample_interval == 0:
789
+ generate_and_save_samples(fixed_samples,global_step)
790
+ last_n = sample_interval
791
+ avg_loss = float(np.mean(batch_losses[-last_n:])) if len(batch_losses) > 0 else 0.0
792
+ avg_tloss = float(np.mean(batch_tlosses[-last_n:])) if len(batch_tlosses) > 0 else 0.0
793
+ avg_grad = float(np.mean(batch_grads[-last_n:])) if len(batch_grads) > 0 else 0.0
794
+ print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, grad: {avg_grad:.6f}")
795
+
796
+ if save_model:
797
+ print("saving:",avg_loss < min_loss*save_barrier)
798
+ if avg_loss < min_loss*save_barrier:
799
+ min_loss = avg_loss
800
+ save_checkpoint(unet)
801
+ if use_wandb:
802
+ avg_data = {}
803
+ avg_data["avg/loss"] = avg_loss
804
+ avg_data["avg/tloss"] = avg_tloss
805
+ avg_data["avg/grad"] = avg_grad
806
+ wandb.log(avg_data, step=global_step)
807
+
808
+ if accelerator.is_main_process:
809
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses)>0 else 0.0
810
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
811
+ if use_wandb:
812
+ wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
813
+
814
+ # Завершение обучения - сохраняем финальную модель
815
+ if loss_ratios.get("dispersive", 0) > 0:
816
+ dispersive_hook.remove_hooks()
817
+ if accelerator.is_main_process:
818
+ print("Обучение завершено! Сохраняем финальную модель...")
819
+ if save_model:
820
+ save_checkpoint(unet,"fp16")
821
+ accelerator.free_memory()
822
+ if torch.distributed.is_initialized():
823
+ torch.distributed.destroy_process_group()
824
+
825
+ print("Готово!")
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:23e75f6769d03f650a4673248b33bfd2af15f07d4e8bc75bbc8eca6420f33a8d
3
  size 3092571208
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15dbbdb970577298a421e4424ea6ff535526b6369a73decc5eac24b2401487d6
3
  size 3092571208