recoilme commited on
Commit
8da9419
·
1 Parent(s): 833b212

undo vae unet

Browse files
samples/unet_192x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 1ea618c4d7e1d3702d5d28dad56eaeb97e50c32f3f62430ab7436b522a3bf3f6
  • Pointer size: 130 Bytes
  • Size of remote file: 35.9 kB

Git LFS Details

  • SHA256: bcb173072314fbf5a824e178af40e04b9c750aa0bc7e754f082019c137f6bf22
  • Pointer size: 130 Bytes
  • Size of remote file: 23.2 kB
samples/unet_256x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 4aeeeefcde43a381f65a82fa265cdb2aa126f321712148c010ddef276bca21b7
  • Pointer size: 130 Bytes
  • Size of remote file: 55.3 kB

Git LFS Details

  • SHA256: 749e4fe4839117d7832a7f8ef24f370d4ed7ac3f361335f18ddee906ea5aa748
  • Pointer size: 130 Bytes
  • Size of remote file: 42.4 kB
samples/unet_320x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 2f788c2916d532e9e0746990fea5491396823643058429671d67d049148fb6a5
  • Pointer size: 130 Bytes
  • Size of remote file: 54.8 kB

Git LFS Details

  • SHA256: 6eb9a0e045f66be7d1b9f001fb9bb116a4758d9f0f1168c07a6dcd9428d83c6b
  • Pointer size: 130 Bytes
  • Size of remote file: 59.9 kB
samples/unet_384x192_0.jpg CHANGED

Git LFS Details

  • SHA256: 8554cc23e82658aaca86123477ad5fa10b7e1014027b36aa76590ffbc6a7c9ad
  • Pointer size: 130 Bytes
  • Size of remote file: 39 kB

Git LFS Details

  • SHA256: a59d870d69412bd49a6deae0453e882b08c0f162d5a6c491688ea080785e9cf0
  • Pointer size: 130 Bytes
  • Size of remote file: 25.6 kB
samples/unet_384x256_0.jpg CHANGED

Git LFS Details

  • SHA256: d5de510fb64c94daf6aafb9d7069c345c34ca5f14a59a7ee8e35db0f2c7d96fc
  • Pointer size: 130 Bytes
  • Size of remote file: 61.6 kB

Git LFS Details

  • SHA256: 2060a9dea0fe09e96f695ea1d0903cd3ed0bc797046a2778cd2f027165e4c311
  • Pointer size: 130 Bytes
  • Size of remote file: 50.4 kB
samples/unet_384x320_0.jpg CHANGED

Git LFS Details

  • SHA256: 135743d6e4efe5613b25df42bcf8569928c20a4665782730edb28d7310b844fa
  • Pointer size: 130 Bytes
  • Size of remote file: 59.4 kB

Git LFS Details

  • SHA256: cb0804fb7cd387e93e32c99977d41d1643f2351f2e48dac4e720edf1d0416c63
  • Pointer size: 130 Bytes
  • Size of remote file: 32.1 kB
samples/unet_384x384_0.jpg CHANGED

Git LFS Details

  • SHA256: b4c2d04b25fb69ddf17a72f4aabc7cbd0c7951f3bb3806a1cbb2e2d195bceafe
  • Pointer size: 130 Bytes
  • Size of remote file: 67.4 kB

Git LFS Details

  • SHA256: 77ac1339c455c90c8e23e76733b07b7d58c989dea893e4263f1b711d2d4ac8bd
  • Pointer size: 130 Bytes
  • Size of remote file: 90.3 kB
src/model_create.ipynb CHANGED
@@ -603,7 +603,7 @@
603
  {
604
  "cell_type": "code",
605
  "execution_count": null,
606
- "id": "1cb4ff0f-36cc-43cf-86a4-aaab9f106725",
607
  "metadata": {},
608
  "outputs": [],
609
  "source": []
 
603
  {
604
  "cell_type": "code",
605
  "execution_count": null,
606
+ "id": "0a1d32f6-23ea-4f6f-b2e8-9584a0c12a0d",
607
  "metadata": {},
608
  "outputs": [],
609
  "source": []
train.py CHANGED
@@ -7,11 +7,12 @@ from torch.utils.data import DataLoader, Sampler
7
  from torch.utils.data.distributed import DistributedSampler
8
  from torch.optim.lr_scheduler import LambdaLR
9
  from collections import defaultdict
10
- from diffusers import UNet2DConditionModel, AutoencoderKL
 
11
  from accelerate import Accelerator
12
  from datasets import load_from_disk
13
  from tqdm import tqdm
14
- from PIL import Image, ImageOps
15
  import wandb
16
  import random
17
  import gc
@@ -30,7 +31,7 @@ project = "unet"
30
  batch_size = 128
31
  base_learning_rate = 9e-5
32
  min_learning_rate = 1e-5
33
- num_epochs = 84
34
  # samples/save per epoch
35
  sample_interval_share = 5
36
  use_wandb = True
@@ -44,8 +45,8 @@ unet_gradient = True
44
  clip_sample = False #Scheduler
45
  fixed_seed = True
46
  shuffle = True
47
- comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
48
- comet_ml_workspace = "recoilme"
49
  torch.backends.cuda.matmul.allow_tf32 = True
50
  torch.backends.cudnn.allow_tf32 = True
51
  torch.backends.cuda.enable_mem_efficient_sdp(False)
@@ -59,7 +60,7 @@ clip_grad_norm = 1.0
59
  steps_offset = 0 # Scheduler
60
  limit = 0
61
  checkpoints_folder = ""
62
- mixed_precision = "no" # "fp16"
63
  gradient_accumulation_steps = 1
64
  accelerator = Accelerator(
65
  mixed_precision=mixed_precision,
@@ -87,7 +88,7 @@ if fixed_seed:
87
  torch.cuda.manual_seed_all(seed)
88
 
89
  # --------------------------- Параметры LoRA ---------------------------
90
- lora_name = ""
91
  lora_rank = 32
92
  lora_alpha = 64
93
 
@@ -102,6 +103,13 @@ def sample_timesteps_bias(
102
  device=None,
103
  mode: str = "beta", # "beta", "uniform"
104
  ) -> torch.Tensor:
 
 
 
 
 
 
 
105
  max_idx = num_train_timesteps - 1 - steps_offset
106
 
107
  if mode == "beta":
@@ -118,16 +126,17 @@ def sample_timesteps_bias(
118
  timesteps = steps_offset + (samples * max_idx).long().to(device)
119
  return timesteps
120
 
121
-
122
  def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
123
  normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
 
124
  logit_normal_samples = torch.sigmoid(normal_samples)
 
125
  return logit_normal_samples
126
 
127
  # --------------------------- Инициализация WandB ---------------------------
128
  if accelerator.is_main_process:
129
  if use_wandb:
130
- wandb.init(project=project + lora_name, config={
131
  "batch_size": batch_size,
132
  "base_learning_rate": base_learning_rate,
133
  "num_epochs": num_epochs,
@@ -141,6 +150,7 @@ if accelerator.is_main_process:
141
  project_name=project,
142
  workspace=comet_ml_workspace
143
  )
 
144
  hyper_params = {
145
  "batch_size": batch_size,
146
  "base_learning_rate": base_learning_rate,
@@ -160,6 +170,7 @@ gen = torch.Generator(device=device)
160
  gen.manual_seed(seed)
161
 
162
  # --------------------------- Загрузка моделей ---------------------------
 
163
  vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
164
 
165
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
@@ -174,6 +185,8 @@ latents_mean = getattr(vae.config, "latents_mean", None)
174
  latents_std = getattr(vae.config, "latents_std", None)
175
 
176
 
 
 
177
  class DistributedResolutionBatchSampler(Sampler):
178
  def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
179
  self.dataset = dataset
@@ -253,16 +266,14 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
253
  fixed_samples = {}
254
  for size, indices in size_groups.items():
255
  n_samples = min(samples_per_group, len(indices))
256
- if len(size_groups) == 1:
257
  n_samples = samples_to_generate
258
  if n_samples == 0:
259
  continue
260
  sample_indices = random.sample(indices, n_samples)
261
  samples_data = [dataset[idx] for idx in sample_indices]
262
-
263
- # FIXED: keep fixed samples on CPU to avoid device/device-transfer issues when creating dataset
264
- latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).cpu()
265
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).cpu()
266
  texts = [item["text"] for item in samples_data]
267
  fixed_samples[size] = (latents, embeddings, texts)
268
 
@@ -274,10 +285,9 @@ if limit > 0:
274
  else:
275
  dataset = load_from_disk(ds_path)
276
 
277
-
278
  def collate_fn_simple(batch):
279
- latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
280
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
281
  return latents, embeddings
282
 
283
  batch_sampler = DistributedResolutionBatchSampler(
@@ -288,23 +298,20 @@ batch_sampler = DistributedResolutionBatchSampler(
288
  shuffle=shuffle
289
  )
290
 
291
- # NOTE: we create dataloader first, then prepare it with accelerator. We'll create optimizer/lr_scheduler after
292
- # we know len(dataloader) (which is per-process after prepare) so that scheduling is consistent.
293
-
294
  dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
295
- print("Total batches (pre-prepare):", len(dataloader))
296
  dataloader = accelerator.prepare(dataloader)
297
 
298
- # --------------------------- Теперь безопасно --- вычисляем шаги и создаём оптимизатор/шедулер ---------------------------
299
- steps_per_epoch = len(dataloader) # this is per-process (after prepare)
300
- total_training_steps = steps_per_epoch * num_epochs
301
- print(f"[rank {accelerator.process_index}] steps_per_epoch={steps_per_epoch}, total_training_steps={total_training_steps}")
302
 
303
- # --------------------------- Создание/загрузка UNet ---------------------------
304
  latest_checkpoint = os.path.join(checkpoints_folder, project)
305
  if os.path.isdir(latest_checkpoint):
306
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
307
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
308
  if torch_compile:
309
  print("compiling")
310
  torch.set_float32_matmul_precision('high')
@@ -318,14 +325,16 @@ if os.path.isdir(latest_checkpoint):
318
  except Exception as e:
319
  print(f"Ошибка при включении SDPA: {e}")
320
  unet.set_use_memory_efficient_attention_xformers(True)
 
321
  else:
 
322
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
323
 
324
- # --------------------------- LoRA (если нужно) ---------------------------
325
  if lora_name:
326
  print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
327
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
328
  from peft.tuners.lora import LoraModel
 
329
  unet.requires_grad_(False)
330
  print("Параметры базового UNet заморожены.")
331
 
@@ -363,8 +372,8 @@ if lora_name:
363
  if lora_name:
364
  trainable_params = [p for p in unet.parameters() if p.requires_grad]
365
  else:
366
- trainable_params = list(unet.parameters()) if fbp else [p for p in unet.parameters() if p.requires_grad]
367
-
368
 
369
  def create_optimizer(name, params):
370
  if name == "adam8bit":
@@ -376,18 +385,6 @@ def create_optimizer(name, params):
376
  return torch.optim.AdamW(
377
  params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
378
  )
379
- elif name == "lion8bit":
380
- return bnb.optim.Lion8bit(
381
- params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
382
- percentile_clipping=percentile_clipping
383
- )
384
- elif name == "adafactor":
385
- from transformers import Adafactor
386
- return Adafactor(
387
- params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
388
- warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
389
- beta1=0.9, weight_decay=0.01
390
- )
391
  else:
392
  raise ValueError(f"Unknown optimizer: {name}")
393
 
@@ -398,15 +395,11 @@ if fbp:
398
  optimizer_dict[param].zero_grad(set_to_none=True)
399
  for param in trainable_params:
400
  param.register_post_accumulate_grad_hook(optimizer_hook)
401
- # FIXED: prepare fbp variant (keeps original logic)
402
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
403
  else:
404
- optimizer = create_optimizer(optimizer_type, trainable_params)
405
-
406
- # FIXED: LR schedule should be based on total_training_steps (per-process steps * epochs)
407
  def lr_schedule(step):
408
- # step is current scheduler step (0..total_training_steps)
409
- x = step / max(1, total_training_steps)
410
  warmup = warmup_percent
411
  if not use_decay:
412
  return base_learning_rate
@@ -415,8 +408,6 @@ else:
415
  decay_ratio = (x - warmup) / (1 - warmup)
416
  return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
417
  (1 + math.cos(math.pi * decay_ratio))
418
-
419
- # LambdaLR expects a multiplier, so divide by base_learning_rate
420
  lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
421
 
422
  num_params = sum(p.numel() for p in unet.parameters())
@@ -424,13 +415,12 @@ else:
424
  for name, param in unet.named_parameters():
425
  if torch.isnan(param).any() or torch.isinf(param).any():
426
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
427
-
428
- # FIXED: prepare model, optimizer, scheduler AFTER creating them and after dataloader.prepare
429
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
430
 
431
  # --------------------------- Фиксированные семплы для генерации ---------------------------
432
  fixed_samples = get_fixed_samples_by_resolution(dataset)
433
 
 
434
  @torch.no_grad()
435
  def generate_and_save_samples(fixed_samples_cpu, step):
436
  original_model = None
@@ -438,12 +428,12 @@ def generate_and_save_samples(fixed_samples_cpu, step):
438
  original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
439
  vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
440
 
 
441
  all_generated_images = []
442
  all_captions = []
443
 
444
  for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
445
  width, height = size
446
- # move CPU tensors to device here (they were kept on CPU in get_fixed_samples_by_resolution)
447
  sample_latents = sample_latents.to(dtype=dtype, device=device)
448
  sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
449
 
@@ -482,11 +472,19 @@ def generate_and_save_samples(fixed_samples_cpu, step):
482
 
483
  # Параметры нормализации
484
  latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
 
485
  decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
 
 
 
486
  decoded_fp32 = decoded.to(torch.float32)
487
  for img_idx, img_tensor in enumerate(decoded_fp32):
 
 
488
  img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
489
- img = img.transpose(1, 2, 0)
 
 
490
  if np.isnan(img).any():
491
  print("NaNs found, saving stopped! Step:", step)
492
  pil_img = Image.fromarray((img * 255).astype("uint8"))
@@ -510,7 +508,7 @@ def generate_and_save_samples(fixed_samples_cpu, step):
510
  wandb.Image(img, caption=f"{all_captions[i]}")
511
  for i, img in enumerate(all_generated_images)
512
  ]
513
- wandb.log({"generated_images": wandb_images}, step=step)
514
  if use_comet_ml and accelerator.is_main_process:
515
  for i, img in enumerate(all_generated_images):
516
  comet_experiment.log_image(
@@ -525,7 +523,11 @@ def generate_and_save_samples(fixed_samples_cpu, step):
525
  }
526
  )
527
  finally:
 
528
  vae.to("cpu")
 
 
 
529
  torch.cuda.empty_cache()
530
  gc.collect()
531
 
@@ -533,23 +535,20 @@ def generate_and_save_samples(fixed_samples_cpu, step):
533
  if accelerator.is_main_process:
534
  if save_model:
535
  print("Генерация сэмплов до старта обучения...")
536
- generate_and_save_samples(fixed_samples, 0)
537
  accelerator.wait_for_everyone()
538
 
539
  # Модиф��цируем функцию сохранения модели для поддержки LoRA
540
- def save_checkpoint(unet, variant=""):
541
  if accelerator.is_main_process:
542
  if lora_name:
543
  save_lora_checkpoint(unet)
544
  else:
545
- # FIXED: don't change dtype of model wrapped by accelerator. Unwrap and save as-is.
546
- model_to_save = accelerator.unwrap_model(unet)
547
- dest = os.path.join(checkpoints_folder, f"{project}")
548
- if variant != "":
549
- model_to_save.save_pretrained(dest, variant=variant)
550
  else:
551
- model_to_save.save_pretrained(dest)
552
-
553
 
554
  # --------------------------- Тренировочный цикл ---------------------------
555
  if accelerator.is_main_process:
@@ -558,25 +557,20 @@ if accelerator.is_main_process:
558
  epoch_loss_points = []
559
  progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
560
 
 
561
  sample_interval = max(1, steps_per_epoch // sample_interval_share)
562
  min_loss = 2.
563
 
564
- for epoch in range(0, num_epochs):
565
- # FIXED: set epoch on the dataloader's batch_sampler if available (accelerator may have wrapped it)
566
- if hasattr(dataloader, "batch_sampler") and hasattr(dataloader.batch_sampler, "set_epoch"):
567
- dataloader.batch_sampler.set_epoch(epoch)
568
- elif hasattr(batch_sampler, "set_epoch"):
569
- batch_sampler.set_epoch(epoch)
570
-
571
- accelerator.wait_for_everyone()
572
- unet.train()
573
-
574
  batch_losses = []
575
  batch_grads = []
576
-
 
 
 
577
  for step, (latents, embeddings) in enumerate(dataloader):
578
  with accelerator.accumulate(unet):
579
- if save_model == False and step == 5:
580
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
581
  print(f"Шаг {step}: {used_gb:.2f} GB")
582
 
@@ -590,94 +584,83 @@ for epoch in range(0, num_epochs):
590
 
591
  mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
592
 
593
- # Сохраняем для локальных логов
594
- batch_losses.append(mse_loss.detach().cpu().item())
595
 
 
 
 
596
  # Backward
597
  accelerator.backward(mse_loss)
598
 
599
- grad_norm_val = 0.0
 
 
 
600
  if not fbp:
601
  if accelerator.sync_gradients:
602
- # Clip gradients and step only when gradients are synchronized (i.e. actual optimizer step)
603
- grad_norm = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
604
- grad_norm_val = float(grad_norm)
605
  optimizer.step()
606
  lr_scheduler.step()
607
  optimizer.zero_grad(set_to_none=True)
608
 
609
- # increment global_step only when we have synchronized gradients (i.e. on optimizer step)
610
- # FIXED: ensure global_step reflects optimizer updates, not micro-batches
611
  if accelerator.sync_gradients:
612
- try:
613
- global_step += 1
614
- except NameError:
615
- global_step = 1
616
  progress_bar.update(1)
617
-
618
- # Aggregate loss across processes for correct logging
619
- loss_tensor = mse_loss.detach().clone()
620
- # move to device if not already
621
- loss_tensor = loss_tensor.to(device)
622
- gathered = accelerator.gather(loss_tensor)
623
- if accelerator.is_main_process:
624
- reduced_loss = gathered.mean().item()
625
- else:
626
- reduced_loss = None
627
-
628
- # Логируем метрики только на главном процессе
629
  if accelerator.is_main_process:
630
  if fbp:
631
  current_lr = base_learning_rate
632
  else:
633
  current_lr = lr_scheduler.get_last_lr()[0]
634
-
635
- batch_grads.append(grad_norm_val)
636
-
637
- log_data = {
638
- "loss": reduced_loss,
639
- "lr": current_lr,
640
- "grad": grad_norm_val,
641
- "epoch": epoch + 1,
642
- "global_step": global_step,
643
- }
644
-
645
- if use_wandb:
646
- wandb.log(log_data, step=global_step)
647
- if use_comet_ml:
648
- comet_experiment.log_metrics(log_data, step=global_step)
649
-
650
- # Генерируем сэмплы с заданным интервалом (только на главном процессе)
651
  if global_step % sample_interval == 0:
652
- generate_and_save_samples(fixed_samples, global_step)
653
-
 
654
  if save_model:
655
- # use recent local losses to decide saving (still local); you may want to use reduced_loss here
656
  avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
657
- if use_wandb:
658
- wandb.log({"sample_loss": avg_sample_loss)
659
- print("saving:", reduced_loss is not None and reduced_loss < min_loss * save_barrier, "Current:", reduced_loss, "Avg:", avg_sample_loss)
660
- if reduced_loss is not None and reduced_loss < min_loss * save_barrier:
661
  min_loss = avg_sample_loss
662
  save_checkpoint(unet)
663
 
664
- # Эпоха окончена — агрегируем и логируем средние значения
665
  if accelerator.is_main_process:
666
  # local averages
667
  avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
668
  avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
669
 
670
- # FIXED: optionally reduce across processes if you want a true global epoch average
671
- # Here we compute local values and log them (main process only). For global average use accelerator.gather
672
- print(f"\nЭпоха {epoch} завершена. Средний лосс (local main proc): {avg_epoch_loss:.6f}")
 
 
 
673
  if use_wandb:
674
- wandb.log({"epoch_loss": avg_epoch_loss, "epoch_grad": avg_epoch_grad, "epoch": epoch + 1})
 
 
675
 
676
  # Завершение обучения - сохраняем финальную модель
677
  if accelerator.is_main_process:
678
  print("Обучение завершено! Сохраняем финальную модель...")
679
  if save_model:
680
- save_checkpoint(unet, "fp16")
681
  if use_comet_ml:
682
  comet_experiment.end()
683
  accelerator.free_memory()
 
7
  from torch.utils.data.distributed import DistributedSampler
8
  from torch.optim.lr_scheduler import LambdaLR
9
  from collections import defaultdict
10
+ from torch.optim.lr_scheduler import LambdaLR
11
+ from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
12
  from accelerate import Accelerator
13
  from datasets import load_from_disk
14
  from tqdm import tqdm
15
+ from PIL import Image,ImageOps
16
  import wandb
17
  import random
18
  import gc
 
31
  batch_size = 128
32
  base_learning_rate = 9e-5
33
  min_learning_rate = 1e-5
34
+ num_epochs = 63
35
  # samples/save per epoch
36
  sample_interval_share = 5
37
  use_wandb = True
 
45
  clip_sample = False #Scheduler
46
  fixed_seed = True
47
  shuffle = True
48
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Добавлен API ключ для Comet ML
49
+ comet_ml_workspace = "recoilme" # Добавлен workspace для Comet ML
50
  torch.backends.cuda.matmul.allow_tf32 = True
51
  torch.backends.cudnn.allow_tf32 = True
52
  torch.backends.cuda.enable_mem_efficient_sdp(False)
 
60
  steps_offset = 0 # Scheduler
61
  limit = 0
62
  checkpoints_folder = ""
63
+ mixed_precision = "no" #"fp16"
64
  gradient_accumulation_steps = 1
65
  accelerator = Accelerator(
66
  mixed_precision=mixed_precision,
 
88
  torch.cuda.manual_seed_all(seed)
89
 
90
  # --------------------------- Параметры LoRA ---------------------------
91
+ lora_name = ""
92
  lora_rank = 32
93
  lora_alpha = 64
94
 
 
103
  device=None,
104
  mode: str = "beta", # "beta", "uniform"
105
  ) -> torch.Tensor:
106
+ """
107
+ Возвращает timesteps с разным bias:
108
+ - beta : как раньше (сдвиг в начало или конец в зависимости от progress)
109
+ - normal : около середины (гауссовое распределение)
110
+ - uniform: равномерно по всем timestep’ам
111
+ """
112
+
113
  max_idx = num_train_timesteps - 1 - steps_offset
114
 
115
  if mode == "beta":
 
126
  timesteps = steps_offset + (samples * max_idx).long().to(device)
127
  return timesteps
128
 
 
129
  def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
130
  normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
131
+
132
  logit_normal_samples = torch.sigmoid(normal_samples)
133
+
134
  return logit_normal_samples
135
 
136
  # --------------------------- Инициализация WandB ---------------------------
137
  if accelerator.is_main_process:
138
  if use_wandb:
139
+ wandb.init(project=project+lora_name, config={
140
  "batch_size": batch_size,
141
  "base_learning_rate": base_learning_rate,
142
  "num_epochs": num_epochs,
 
150
  project_name=project,
151
  workspace=comet_ml_workspace
152
  )
153
+ # Логируем гиперпараметры в Comet ML
154
  hyper_params = {
155
  "batch_size": batch_size,
156
  "base_learning_rate": base_learning_rate,
 
170
  gen.manual_seed(seed)
171
 
172
  # --------------------------- Загрузка моделей ---------------------------
173
+ # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
174
  vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
175
 
176
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
 
185
  latents_std = getattr(vae.config, "latents_std", None)
186
 
187
 
188
+
189
+
190
  class DistributedResolutionBatchSampler(Sampler):
191
  def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
192
  self.dataset = dataset
 
266
  fixed_samples = {}
267
  for size, indices in size_groups.items():
268
  n_samples = min(samples_per_group, len(indices))
269
+ if len(size_groups)==1:
270
  n_samples = samples_to_generate
271
  if n_samples == 0:
272
  continue
273
  sample_indices = random.sample(indices, n_samples)
274
  samples_data = [dataset[idx] for idx in sample_indices]
275
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
276
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
 
 
277
  texts = [item["text"] for item in samples_data]
278
  fixed_samples[size] = (latents, embeddings, texts)
279
 
 
285
  else:
286
  dataset = load_from_disk(ds_path)
287
 
 
288
  def collate_fn_simple(batch):
289
+ latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
290
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
291
  return latents, embeddings
292
 
293
  batch_sampler = DistributedResolutionBatchSampler(
 
298
  shuffle=shuffle
299
  )
300
 
 
 
 
301
  dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
302
+ print("Total samples",len(dataloader))
303
  dataloader = accelerator.prepare(dataloader)
304
 
305
+ start_epoch = 0
306
+ global_step = 0
307
+ total_training_steps = (len(dataloader) * num_epochs)
308
+ world_size = accelerator.state.num_processes
309
 
310
+ # Опция загрузки модели из последнего чекпоинта (если существует)
311
  latest_checkpoint = os.path.join(checkpoints_folder, project)
312
  if os.path.isdir(latest_checkpoint):
313
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
314
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
315
  if torch_compile:
316
  print("compiling")
317
  torch.set_float32_matmul_precision('high')
 
325
  except Exception as e:
326
  print(f"Ошибка при включении SDPA: {e}")
327
  unet.set_use_memory_efficient_attention_xformers(True)
328
+
329
  else:
330
+ # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
331
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
332
 
 
333
  if lora_name:
334
  print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
335
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
336
  from peft.tuners.lora import LoraModel
337
+ import os
338
  unet.requires_grad_(False)
339
  print("Параметры базового UNet заморожены.")
340
 
 
372
  if lora_name:
373
  trainable_params = [p for p in unet.parameters() if p.requires_grad]
374
  else:
375
+ if fbp:
376
+ trainable_params = list(unet.parameters())
377
 
378
  def create_optimizer(name, params):
379
  if name == "adam8bit":
 
385
  return torch.optim.AdamW(
386
  params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
387
  )
 
 
 
 
 
 
 
 
 
 
 
 
388
  else:
389
  raise ValueError(f"Unknown optimizer: {name}")
390
 
 
395
  optimizer_dict[param].zero_grad(set_to_none=True)
396
  for param in trainable_params:
397
  param.register_post_accumulate_grad_hook(optimizer_hook)
 
398
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
399
  else:
400
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
 
 
401
  def lr_schedule(step):
402
+ x = step / (total_training_steps * world_size)
 
403
  warmup = warmup_percent
404
  if not use_decay:
405
  return base_learning_rate
 
408
  decay_ratio = (x - warmup) / (1 - warmup)
409
  return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
410
  (1 + math.cos(math.pi * decay_ratio))
 
 
411
  lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
412
 
413
  num_params = sum(p.numel() for p in unet.parameters())
 
415
  for name, param in unet.named_parameters():
416
  if torch.isnan(param).any() or torch.isinf(param).any():
417
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
 
 
418
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
419
 
420
  # --------------------------- Фиксированные семплы для генерации ---------------------------
421
  fixed_samples = get_fixed_samples_by_resolution(dataset)
422
 
423
+ @torch.compiler.disable()
424
  @torch.no_grad()
425
  def generate_and_save_samples(fixed_samples_cpu, step):
426
  original_model = None
 
428
  original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
429
  vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
430
 
431
+
432
  all_generated_images = []
433
  all_captions = []
434
 
435
  for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
436
  width, height = size
 
437
  sample_latents = sample_latents.to(dtype=dtype, device=device)
438
  sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
439
 
 
472
 
473
  # Параметры нормализации
474
  latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
475
+
476
  decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
477
+ #decoded = decoded[:, :, 0, :, :] # [3, H, W]
478
+ #print(decoded.ndim, decoded.shape)
479
+
480
  decoded_fp32 = decoded.to(torch.float32)
481
  for img_idx, img_tensor in enumerate(decoded_fp32):
482
+
483
+ # Форма: [3, H, W] -> преобразуем в [H, W, 3]
484
  img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
485
+ img = img.transpose(1, 2, 0) # Из [3, H, W] в [H, W, 3]
486
+
487
+ #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
488
  if np.isnan(img).any():
489
  print("NaNs found, saving stopped! Step:", step)
490
  pil_img = Image.fromarray((img * 255).astype("uint8"))
 
508
  wandb.Image(img, caption=f"{all_captions[i]}")
509
  for i, img in enumerate(all_generated_images)
510
  ]
511
+ wandb.log({"generated_images": wandb_images})
512
  if use_comet_ml and accelerator.is_main_process:
513
  for i, img in enumerate(all_generated_images):
514
  comet_experiment.log_image(
 
523
  }
524
  )
525
  finally:
526
+ # вернуть VAE на CPU (как было в твоём коде)
527
  vae.to("cpu")
528
+ for var in list(locals().keys()):
529
+ if isinstance(locals()[var], torch.Tensor):
530
+ del locals()[var]
531
  torch.cuda.empty_cache()
532
  gc.collect()
533
 
 
535
  if accelerator.is_main_process:
536
  if save_model:
537
  print("Генерация сэмплов до старта обучения...")
538
+ generate_and_save_samples(fixed_samples,0)
539
  accelerator.wait_for_everyone()
540
 
541
  # Модиф��цируем функцию сохранения модели для поддержки LoRA
542
+ def save_checkpoint(unet,variant=""):
543
  if accelerator.is_main_process:
544
  if lora_name:
545
  save_lora_checkpoint(unet)
546
  else:
547
+ if variant!="":
548
+ accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
 
 
 
549
  else:
550
+ accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
551
+ unet = unet.to(dtype=dtype)
552
 
553
  # --------------------------- Тренировочный цикл ---------------------------
554
  if accelerator.is_main_process:
 
557
  epoch_loss_points = []
558
  progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
559
 
560
+ steps_per_epoch = len(dataloader)
561
  sample_interval = max(1, steps_per_epoch // sample_interval_share)
562
  min_loss = 2.
563
 
564
+ for epoch in range(start_epoch, start_epoch + num_epochs):
 
 
 
 
 
 
 
 
 
565
  batch_losses = []
566
  batch_grads = []
567
+ batch_sampler.set_epoch(epoch)
568
+ accelerator.wait_for_everyone()
569
+ unet.train()
570
+ #print("epoch:",epoch)
571
  for step, (latents, embeddings) in enumerate(dataloader):
572
  with accelerator.accumulate(unet):
573
+ if save_model == False and step == 5 :
574
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
575
  print(f"Шаг {step}: {used_gb:.2f} GB")
576
 
 
584
 
585
  mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
586
 
587
+ # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
588
+ batch_losses.append(mse_loss.detach().item())
589
 
590
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
591
+ accelerator.wait_for_everyone()
592
+
593
  # Backward
594
  accelerator.backward(mse_loss)
595
 
596
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
597
+ accelerator.wait_for_everyone()
598
+
599
+ grad = 0.0
600
  if not fbp:
601
  if accelerator.sync_gradients:
602
+ with torch.amp.autocast('cuda', enabled=False):
603
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
604
+ grad = float(grad_val)
605
  optimizer.step()
606
  lr_scheduler.step()
607
  optimizer.zero_grad(set_to_none=True)
608
 
 
 
609
  if accelerator.sync_gradients:
610
+ global_step += 1
 
 
 
611
  progress_bar.update(1)
612
+ # Логируем метрики
 
 
 
 
 
 
 
 
 
 
 
613
  if accelerator.is_main_process:
614
  if fbp:
615
  current_lr = base_learning_rate
616
  else:
617
  current_lr = lr_scheduler.get_last_lr()[0]
618
+ batch_grads.append(grad)
619
+
620
+ log_data = {}
621
+ log_data["loss"] = mse_loss.detach().item()
622
+ log_data["lr"] = current_lr
623
+ log_data["grad"] = grad
624
+ if accelerator.sync_gradients:
625
+ if use_wandb:
626
+ wandb.log(log_data, step=global_step)
627
+ if use_comet_ml:
628
+ comet_experiment.log_metrics(log_data, step=global_step)
629
+
630
+ # Генерируем сэмплы с заданным интервалом
 
 
 
 
631
  if global_step % sample_interval == 0:
632
+ generate_and_save_samples(fixed_samples,global_step)
633
+ last_n = sample_interval
634
+
635
  if save_model:
 
636
  avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
637
+ print("saving:", avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
638
+ if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
 
 
639
  min_loss = avg_sample_loss
640
  save_checkpoint(unet)
641
 
642
+
643
  if accelerator.is_main_process:
644
  # local averages
645
  avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
646
  avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
647
 
648
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
649
+ log_data_ep = {
650
+ "epoch_loss": avg_epoch_loss,
651
+ "epoch_grad": avg_epoch_grad,
652
+ "epoch": epoch + 1,
653
+ }
654
  if use_wandb:
655
+ wandb.log(log_data_ep)
656
+ if use_comet_ml:
657
+ comet_experiment.log_metrics(log_data_ep)
658
 
659
  # Завершение обучения - сохраняем финальную модель
660
  if accelerator.is_main_process:
661
  print("Обучение завершено! Сохраняем финальную модель...")
662
  if save_model:
663
+ save_checkpoint(unet,"fp16")
664
  if use_comet_ml:
665
  comet_experiment.end()
666
  accelerator.free_memory()
train-Copy1.py → train_chatgpt.py RENAMED
@@ -7,12 +7,11 @@ from torch.utils.data import DataLoader, Sampler
7
  from torch.utils.data.distributed import DistributedSampler
8
  from torch.optim.lr_scheduler import LambdaLR
9
  from collections import defaultdict
10
- from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
12
  from accelerate import Accelerator
13
  from datasets import load_from_disk
14
  from tqdm import tqdm
15
- from PIL import Image,ImageOps
16
  import wandb
17
  import random
18
  import gc
@@ -26,17 +25,17 @@ import torch.nn.functional as F
26
  from collections import deque
27
 
28
  # --------------------------- Параметры ---------------------------
29
- ds_path = "/workspace/sdxs3d/datasets/mjnj"
30
  project = "unet"
31
  batch_size = 128
32
  base_learning_rate = 9e-5
33
  min_learning_rate = 1e-5
34
- num_epochs = 84
35
  # samples/save per epoch
36
- sample_interval_share = 5
37
  use_wandb = True
38
  use_comet_ml = False
39
- save_model = True
40
  use_decay = True
41
  fbp = False # fused backward pass
42
  optimizer_type = "adam8bit"
@@ -45,8 +44,8 @@ unet_gradient = True
45
  clip_sample = False #Scheduler
46
  fixed_seed = True
47
  shuffle = True
48
- comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Добавлен API ключ для Comet ML
49
- comet_ml_workspace = "recoilme" # Добавлен workspace для Comet ML
50
  torch.backends.cuda.matmul.allow_tf32 = True
51
  torch.backends.cudnn.allow_tf32 = True
52
  torch.backends.cuda.enable_mem_efficient_sdp(False)
@@ -60,7 +59,7 @@ clip_grad_norm = 1.0
60
  steps_offset = 0 # Scheduler
61
  limit = 0
62
  checkpoints_folder = ""
63
- mixed_precision = "no" #"fp16"
64
  gradient_accumulation_steps = 1
65
  accelerator = Accelerator(
66
  mixed_precision=mixed_precision,
@@ -88,7 +87,7 @@ if fixed_seed:
88
  torch.cuda.manual_seed_all(seed)
89
 
90
  # --------------------------- Параметры LoRA ---------------------------
91
- lora_name = ""
92
  lora_rank = 32
93
  lora_alpha = 64
94
 
@@ -103,13 +102,6 @@ def sample_timesteps_bias(
103
  device=None,
104
  mode: str = "beta", # "beta", "uniform"
105
  ) -> torch.Tensor:
106
- """
107
- Возвращает timesteps с разным bias:
108
- - beta : как раньше (сдвиг в начало или конец в зависимости от progress)
109
- - normal : около середины (гауссовое распределение)
110
- - uniform: равномерно по всем timestep’ам
111
- """
112
-
113
  max_idx = num_train_timesteps - 1 - steps_offset
114
 
115
  if mode == "beta":
@@ -126,17 +118,16 @@ def sample_timesteps_bias(
126
  timesteps = steps_offset + (samples * max_idx).long().to(device)
127
  return timesteps
128
 
 
129
  def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
130
  normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
131
-
132
  logit_normal_samples = torch.sigmoid(normal_samples)
133
-
134
  return logit_normal_samples
135
 
136
  # --------------------------- Инициализация WandB ---------------------------
137
  if accelerator.is_main_process:
138
  if use_wandb:
139
- wandb.init(project=project+lora_name, config={
140
  "batch_size": batch_size,
141
  "base_learning_rate": base_learning_rate,
142
  "num_epochs": num_epochs,
@@ -150,7 +141,6 @@ if accelerator.is_main_process:
150
  project_name=project,
151
  workspace=comet_ml_workspace
152
  )
153
- # Логируем гиперпараметры в Comet ML
154
  hyper_params = {
155
  "batch_size": batch_size,
156
  "base_learning_rate": base_learning_rate,
@@ -170,8 +160,7 @@ gen = torch.Generator(device=device)
170
  gen.manual_seed(seed)
171
 
172
  # --------------------------- Загрузка моделей ---------------------------
173
- # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
174
- vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to("cpu").eval()
175
 
176
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
177
  if shift_factor is None:
@@ -185,8 +174,6 @@ latents_mean = getattr(vae.config, "latents_mean", None)
185
  latents_std = getattr(vae.config, "latents_std", None)
186
 
187
 
188
-
189
-
190
  class DistributedResolutionBatchSampler(Sampler):
191
  def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
192
  self.dataset = dataset
@@ -266,14 +253,16 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
266
  fixed_samples = {}
267
  for size, indices in size_groups.items():
268
  n_samples = min(samples_per_group, len(indices))
269
- if len(size_groups)==1:
270
  n_samples = samples_to_generate
271
  if n_samples == 0:
272
  continue
273
  sample_indices = random.sample(indices, n_samples)
274
  samples_data = [dataset[idx] for idx in sample_indices]
275
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
276
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
 
 
277
  texts = [item["text"] for item in samples_data]
278
  fixed_samples[size] = (latents, embeddings, texts)
279
 
@@ -285,9 +274,10 @@ if limit > 0:
285
  else:
286
  dataset = load_from_disk(ds_path)
287
 
 
288
  def collate_fn_simple(batch):
289
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
290
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
291
  return latents, embeddings
292
 
293
  batch_sampler = DistributedResolutionBatchSampler(
@@ -298,20 +288,23 @@ batch_sampler = DistributedResolutionBatchSampler(
298
  shuffle=shuffle
299
  )
300
 
 
 
 
301
  dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
302
- print("Total samples",len(dataloader))
303
  dataloader = accelerator.prepare(dataloader)
304
 
305
- start_epoch = 0
306
- global_step = 0
307
- total_training_steps = (len(dataloader) * num_epochs)
308
- world_size = accelerator.state.num_processes
309
 
310
- # Опция загрузки модели из последнего чекпоинта (если существует)
311
  latest_checkpoint = os.path.join(checkpoints_folder, project)
312
  if os.path.isdir(latest_checkpoint):
313
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
314
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
315
  if torch_compile:
316
  print("compiling")
317
  torch.set_float32_matmul_precision('high')
@@ -325,16 +318,14 @@ if os.path.isdir(latest_checkpoint):
325
  except Exception as e:
326
  print(f"Ошибка при включении SDPA: {e}")
327
  unet.set_use_memory_efficient_attention_xformers(True)
328
-
329
  else:
330
- # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
331
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
332
 
 
333
  if lora_name:
334
  print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
335
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
336
  from peft.tuners.lora import LoraModel
337
- import os
338
  unet.requires_grad_(False)
339
  print("Параметры базового UNet заморожены.")
340
 
@@ -372,8 +363,8 @@ if lora_name:
372
  if lora_name:
373
  trainable_params = [p for p in unet.parameters() if p.requires_grad]
374
  else:
375
- if fbp:
376
- trainable_params = list(unet.parameters())
377
 
378
  def create_optimizer(name, params):
379
  if name == "adam8bit":
@@ -407,11 +398,15 @@ if fbp:
407
  optimizer_dict[param].zero_grad(set_to_none=True)
408
  for param in trainable_params:
409
  param.register_post_accumulate_grad_hook(optimizer_hook)
 
410
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
411
  else:
412
- optimizer = create_optimizer(optimizer_type, unet.parameters())
 
 
413
  def lr_schedule(step):
414
- x = step / (total_training_steps * world_size)
 
415
  warmup = warmup_percent
416
  if not use_decay:
417
  return base_learning_rate
@@ -420,6 +415,8 @@ else:
420
  decay_ratio = (x - warmup) / (1 - warmup)
421
  return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
422
  (1 + math.cos(math.pi * decay_ratio))
 
 
423
  lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
424
 
425
  num_params = sum(p.numel() for p in unet.parameters())
@@ -427,12 +424,13 @@ else:
427
  for name, param in unet.named_parameters():
428
  if torch.isnan(param).any() or torch.isinf(param).any():
429
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
 
 
430
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
431
 
432
  # --------------------------- Фиксированные семплы для генерации ---------------------------
433
  fixed_samples = get_fixed_samples_by_resolution(dataset)
434
 
435
- @torch.compiler.disable()
436
  @torch.no_grad()
437
  def generate_and_save_samples(fixed_samples_cpu, step):
438
  original_model = None
@@ -440,12 +438,12 @@ def generate_and_save_samples(fixed_samples_cpu, step):
440
  original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
441
  vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
442
 
443
-
444
  all_generated_images = []
445
  all_captions = []
446
 
447
  for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
448
  width, height = size
 
449
  sample_latents = sample_latents.to(dtype=dtype, device=device)
450
  sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
451
 
@@ -484,19 +482,11 @@ def generate_and_save_samples(fixed_samples_cpu, step):
484
 
485
  # Параметры нормализации
486
  latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
487
-
488
  decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
489
- #decoded = decoded[:, :, 0, :, :] # [3, H, W]
490
- #print(decoded.ndim, decoded.shape)
491
-
492
  decoded_fp32 = decoded.to(torch.float32)
493
  for img_idx, img_tensor in enumerate(decoded_fp32):
494
-
495
- # Форма: [3, H, W] -> преобразуем в [H, W, 3]
496
  img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
497
- img = img.transpose(1, 2, 0) # Из [3, H, W] в [H, W, 3]
498
-
499
- #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
500
  if np.isnan(img).any():
501
  print("NaNs found, saving stopped! Step:", step)
502
  pil_img = Image.fromarray((img * 255).astype("uint8"))
@@ -520,7 +510,7 @@ def generate_and_save_samples(fixed_samples_cpu, step):
520
  wandb.Image(img, caption=f"{all_captions[i]}")
521
  for i, img in enumerate(all_generated_images)
522
  ]
523
- wandb.log({"generated_images": wandb_images})
524
  if use_comet_ml and accelerator.is_main_process:
525
  for i, img in enumerate(all_generated_images):
526
  comet_experiment.log_image(
@@ -535,11 +525,7 @@ def generate_and_save_samples(fixed_samples_cpu, step):
535
  }
536
  )
537
  finally:
538
- # вернуть VAE на CPU (как было в твоём коде)
539
  vae.to("cpu")
540
- for var in list(locals().keys()):
541
- if isinstance(locals()[var], torch.Tensor):
542
- del locals()[var]
543
  torch.cuda.empty_cache()
544
  gc.collect()
545
 
@@ -547,20 +533,22 @@ def generate_and_save_samples(fixed_samples_cpu, step):
547
  if accelerator.is_main_process:
548
  if save_model:
549
  print("Генерация сэмплов до старта обучения...")
550
- generate_and_save_samples(fixed_samples,0)
551
  accelerator.wait_for_everyone()
552
 
553
  # Модифицируем функцию сохранения модели для поддержки LoRA
554
- def save_checkpoint(unet,variant=""):
555
  if accelerator.is_main_process:
556
  if lora_name:
557
  save_lora_checkpoint(unet)
558
  else:
559
- if variant!="":
560
- accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
 
 
561
  else:
562
- accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
563
- unet = unet.to(dtype=dtype)
564
 
565
  # --------------------------- Тренировочный цикл ---------------------------
566
  if accelerator.is_main_process:
@@ -569,20 +557,26 @@ if accelerator.is_main_process:
569
  epoch_loss_points = []
570
  progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
571
 
572
- steps_per_epoch = len(dataloader)
573
  sample_interval = max(1, steps_per_epoch // sample_interval_share)
574
  min_loss = 2.
 
 
 
 
 
 
 
 
575
 
576
- for epoch in range(start_epoch, start_epoch + num_epochs):
577
- batch_losses = []
578
- batch_grads = []
579
- batch_sampler.set_epoch(epoch)
580
  accelerator.wait_for_everyone()
581
  unet.train()
582
- #print("epoch:",epoch)
 
 
 
583
  for step, (latents, embeddings) in enumerate(dataloader):
584
  with accelerator.accumulate(unet):
585
- if save_model == False and step == 5 :
586
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
587
  print(f"Шаг {step}: {used_gb:.2f} GB")
588
 
@@ -596,80 +590,85 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
596
 
597
  mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
598
 
599
- # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
600
- batch_losses.append(mse_loss.detach().item())
601
 
602
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
603
- accelerator.wait_for_everyone()
604
-
605
  # Backward
606
  accelerator.backward(mse_loss)
607
 
608
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
609
- accelerator.wait_for_everyone()
610
-
611
- grad = 0.0
612
  if not fbp:
613
  if accelerator.sync_gradients:
614
- with torch.amp.autocast('cuda', enabled=False):
615
- grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
616
- grad = float(grad_val)
617
  optimizer.step()
618
  lr_scheduler.step()
619
  optimizer.zero_grad(set_to_none=True)
620
 
621
- global_step += 1
622
- progress_bar.update(1)
623
-
624
- # Логируем метрики
625
- if accelerator.is_main_process:
626
- if fbp:
627
- current_lr = base_learning_rate
628
- else:
629
- current_lr = lr_scheduler.get_last_lr()[0]
630
- batch_grads.append(grad)
 
 
 
 
 
 
 
 
 
 
 
 
631
 
632
- log_data = {}
633
- log_data["loss"] = mse_loss.detach().item()
634
- log_data["lr"] = current_lr
635
- log_data["grad"] = grad
636
- if accelerator.sync_gradients:
637
  if use_wandb:
638
  wandb.log(log_data, step=global_step)
639
  if use_comet_ml:
640
  comet_experiment.log_metrics(log_data, step=global_step)
641
 
642
- # Генерируем сэмплы с заданным интервалом
643
- if global_step % sample_interval == 0:
644
- generate_and_save_samples(fixed_samples,global_step)
645
- last_n = sample_interval
646
-
647
- if save_model:
648
- avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses)>0 else 0.0
649
- print("saving:",avg_sample_loss < min_loss*save_barrier,"Current:",log_data["loss"],"Avg:",avg_sample_loss)
650
- if log_data["loss"] < min_loss*save_barrier:
651
- min_loss = avg_sample_loss
652
- save_checkpoint(unet)
653
- if use_wandb:
654
- wandb.log(log_data, step=global_step)
655
- if use_comet_ml:
656
- comet_experiment.log_metrics(log_data, step=global_step)
657
 
 
 
 
 
 
 
 
 
658
 
 
659
  if accelerator.is_main_process:
660
- avg_epoch_loss = np.mean(batch_losses[-steps_per_epoch:]) if len(batch_losses)>0 else 0.0
661
- avg_epoch_grad = np.mean(batch_grads[-steps_per_epoch:]) if len(batch_grads)>0 else 0.0
 
 
662
  print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
 
 
 
 
 
663
  if use_wandb:
664
- wandb.log({"epoch_loss": avg_epoch_loss, "epoch_grad": avg_epoch_grad, "epoch": epoch+1})
665
- #if use_comet_ml:
666
- # comet_experiment.log_metrics(epoch_data)
667
 
668
  # Завершение обучения - сохраняем финальную модель
669
  if accelerator.is_main_process:
670
  print("Обучение завершено! Сохраняем финальную модель...")
671
  if save_model:
672
- save_checkpoint(unet,"fp16")
673
  if use_comet_ml:
674
  comet_experiment.end()
675
  accelerator.free_memory()
 
7
  from torch.utils.data.distributed import DistributedSampler
8
  from torch.optim.lr_scheduler import LambdaLR
9
  from collections import defaultdict
10
+ from diffusers import UNet2DConditionModel, AutoencoderKL
 
11
  from accelerate import Accelerator
12
  from datasets import load_from_disk
13
  from tqdm import tqdm
14
+ from PIL import Image, ImageOps
15
  import wandb
16
  import random
17
  import gc
 
25
  from collections import deque
26
 
27
  # --------------------------- Параметры ---------------------------
28
+ ds_path = "/workspace/sdxs3d/datasets/butterfly"
29
  project = "unet"
30
  batch_size = 128
31
  base_learning_rate = 9e-5
32
  min_learning_rate = 1e-5
33
+ num_epochs = 100
34
  # samples/save per epoch
35
+ sample_interval_share = 1
36
  use_wandb = True
37
  use_comet_ml = False
38
+ save_model = False
39
  use_decay = True
40
  fbp = False # fused backward pass
41
  optimizer_type = "adam8bit"
 
44
  clip_sample = False #Scheduler
45
  fixed_seed = True
46
  shuffle = True
47
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
48
+ comet_ml_workspace = "recoilme"
49
  torch.backends.cuda.matmul.allow_tf32 = True
50
  torch.backends.cudnn.allow_tf32 = True
51
  torch.backends.cuda.enable_mem_efficient_sdp(False)
 
59
  steps_offset = 0 # Scheduler
60
  limit = 0
61
  checkpoints_folder = ""
62
+ mixed_precision = "no" # "fp16"
63
  gradient_accumulation_steps = 1
64
  accelerator = Accelerator(
65
  mixed_precision=mixed_precision,
 
87
  torch.cuda.manual_seed_all(seed)
88
 
89
  # --------------------------- Параметры LoRA ---------------------------
90
+ lora_name = ""
91
  lora_rank = 32
92
  lora_alpha = 64
93
 
 
102
  device=None,
103
  mode: str = "beta", # "beta", "uniform"
104
  ) -> torch.Tensor:
 
 
 
 
 
 
 
105
  max_idx = num_train_timesteps - 1 - steps_offset
106
 
107
  if mode == "beta":
 
118
  timesteps = steps_offset + (samples * max_idx).long().to(device)
119
  return timesteps
120
 
121
+
122
  def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
123
  normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
 
124
  logit_normal_samples = torch.sigmoid(normal_samples)
 
125
  return logit_normal_samples
126
 
127
  # --------------------------- Инициализация WandB ---------------------------
128
  if accelerator.is_main_process:
129
  if use_wandb:
130
+ wandb.init(project=project + lora_name, config={
131
  "batch_size": batch_size,
132
  "base_learning_rate": base_learning_rate,
133
  "num_epochs": num_epochs,
 
141
  project_name=project,
142
  workspace=comet_ml_workspace
143
  )
 
144
  hyper_params = {
145
  "batch_size": batch_size,
146
  "base_learning_rate": base_learning_rate,
 
160
  gen.manual_seed(seed)
161
 
162
  # --------------------------- Загрузка моделей ---------------------------
163
+ vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
 
164
 
165
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
166
  if shift_factor is None:
 
174
  latents_std = getattr(vae.config, "latents_std", None)
175
 
176
 
 
 
177
  class DistributedResolutionBatchSampler(Sampler):
178
  def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
179
  self.dataset = dataset
 
253
  fixed_samples = {}
254
  for size, indices in size_groups.items():
255
  n_samples = min(samples_per_group, len(indices))
256
+ if len(size_groups) == 1:
257
  n_samples = samples_to_generate
258
  if n_samples == 0:
259
  continue
260
  sample_indices = random.sample(indices, n_samples)
261
  samples_data = [dataset[idx] for idx in sample_indices]
262
+
263
+ # FIXED: keep fixed samples on CPU to avoid device/device-transfer issues when creating dataset
264
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).cpu()
265
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).cpu()
266
  texts = [item["text"] for item in samples_data]
267
  fixed_samples[size] = (latents, embeddings, texts)
268
 
 
274
  else:
275
  dataset = load_from_disk(ds_path)
276
 
277
+
278
  def collate_fn_simple(batch):
279
+ latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
280
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
281
  return latents, embeddings
282
 
283
  batch_sampler = DistributedResolutionBatchSampler(
 
288
  shuffle=shuffle
289
  )
290
 
291
+ # NOTE: we create dataloader first, then prepare it with accelerator. We'll create optimizer/lr_scheduler after
292
+ # we know len(dataloader) (which is per-process after prepare) so that scheduling is consistent.
293
+
294
  dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
295
+ print("Total batches (pre-prepare):", len(dataloader))
296
  dataloader = accelerator.prepare(dataloader)
297
 
298
+ # --------------------------- Теперь безопасно --- вычисляем шаги и создаём оптимизатор/шедулер ---------------------------
299
+ steps_per_epoch = len(dataloader) # this is per-process (after prepare)
300
+ total_training_steps = steps_per_epoch * num_epochs
301
+ print(f"[rank {accelerator.process_index}] steps_per_epoch={steps_per_epoch}, total_training_steps={total_training_steps}")
302
 
303
+ # --------------------------- Создание/загрузка UNet ---------------------------
304
  latest_checkpoint = os.path.join(checkpoints_folder, project)
305
  if os.path.isdir(latest_checkpoint):
306
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
307
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
308
  if torch_compile:
309
  print("compiling")
310
  torch.set_float32_matmul_precision('high')
 
318
  except Exception as e:
319
  print(f"Ошибка при включении SDPA: {e}")
320
  unet.set_use_memory_efficient_attention_xformers(True)
 
321
  else:
 
322
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
323
 
324
+ # --------------------------- LoRA (если нужно) ---------------------------
325
  if lora_name:
326
  print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
327
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
328
  from peft.tuners.lora import LoraModel
 
329
  unet.requires_grad_(False)
330
  print("Параметры базового UNet заморожены.")
331
 
 
363
  if lora_name:
364
  trainable_params = [p for p in unet.parameters() if p.requires_grad]
365
  else:
366
+ trainable_params = list(unet.parameters()) if fbp else [p for p in unet.parameters() if p.requires_grad]
367
+
368
 
369
  def create_optimizer(name, params):
370
  if name == "adam8bit":
 
398
  optimizer_dict[param].zero_grad(set_to_none=True)
399
  for param in trainable_params:
400
  param.register_post_accumulate_grad_hook(optimizer_hook)
401
+ # FIXED: prepare fbp variant (keeps original logic)
402
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
403
  else:
404
+ optimizer = create_optimizer(optimizer_type, trainable_params)
405
+
406
+ # FIXED: LR schedule should be based on total_training_steps (per-process steps * epochs)
407
  def lr_schedule(step):
408
+ # step is current scheduler step (0..total_training_steps)
409
+ x = step / max(1, total_training_steps)
410
  warmup = warmup_percent
411
  if not use_decay:
412
  return base_learning_rate
 
415
  decay_ratio = (x - warmup) / (1 - warmup)
416
  return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
417
  (1 + math.cos(math.pi * decay_ratio))
418
+
419
+ # LambdaLR expects a multiplier, so divide by base_learning_rate
420
  lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
421
 
422
  num_params = sum(p.numel() for p in unet.parameters())
 
424
  for name, param in unet.named_parameters():
425
  if torch.isnan(param).any() or torch.isinf(param).any():
426
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
427
+
428
+ # FIXED: prepare model, optimizer, scheduler AFTER creating them and after dataloader.prepare
429
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
430
 
431
  # --------------------------- Фиксированные семплы для генерации ---------------------------
432
  fixed_samples = get_fixed_samples_by_resolution(dataset)
433
 
 
434
  @torch.no_grad()
435
  def generate_and_save_samples(fixed_samples_cpu, step):
436
  original_model = None
 
438
  original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
439
  vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
440
 
 
441
  all_generated_images = []
442
  all_captions = []
443
 
444
  for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
445
  width, height = size
446
+ # move CPU tensors to device here (they were kept on CPU in get_fixed_samples_by_resolution)
447
  sample_latents = sample_latents.to(dtype=dtype, device=device)
448
  sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
449
 
 
482
 
483
  # Параметры нормализации
484
  latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
 
485
  decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
 
 
 
486
  decoded_fp32 = decoded.to(torch.float32)
487
  for img_idx, img_tensor in enumerate(decoded_fp32):
 
 
488
  img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
489
+ img = img.transpose(1, 2, 0)
 
 
490
  if np.isnan(img).any():
491
  print("NaNs found, saving stopped! Step:", step)
492
  pil_img = Image.fromarray((img * 255).astype("uint8"))
 
510
  wandb.Image(img, caption=f"{all_captions[i]}")
511
  for i, img in enumerate(all_generated_images)
512
  ]
513
+ wandb.log({"generated_images": wandb_images}, step=step)
514
  if use_comet_ml and accelerator.is_main_process:
515
  for i, img in enumerate(all_generated_images):
516
  comet_experiment.log_image(
 
525
  }
526
  )
527
  finally:
 
528
  vae.to("cpu")
 
 
 
529
  torch.cuda.empty_cache()
530
  gc.collect()
531
 
 
533
  if accelerator.is_main_process:
534
  if save_model:
535
  print("Генерация сэмплов до старта обучения...")
536
+ generate_and_save_samples(fixed_samples, 0)
537
  accelerator.wait_for_everyone()
538
 
539
  # Модифицируем функцию сохранения модели для поддержки LoRA
540
+ def save_checkpoint(unet, variant=""):
541
  if accelerator.is_main_process:
542
  if lora_name:
543
  save_lora_checkpoint(unet)
544
  else:
545
+ model_to_save = accelerator.unwrap_model(unet)
546
+ dest = os.path.join(checkpoints_folder, f"{project}")
547
+ if variant != "":
548
+ model_to_save.to(dtype=torch.float16).save_pretrained(dest, variant=variant)
549
  else:
550
+ model_to_save.save_pretrained(dest)
551
+
552
 
553
  # --------------------------- Тренировочный цикл ---------------------------
554
  if accelerator.is_main_process:
 
557
  epoch_loss_points = []
558
  progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
559
 
 
560
  sample_interval = max(1, steps_per_epoch // sample_interval_share)
561
  min_loss = 2.
562
+ global_step = 0
563
+
564
+ for epoch in range(0, num_epochs):
565
+ # FIXED: set epoch on the dataloader's batch_sampler if available (accelerator may have wrapped it)
566
+ if hasattr(dataloader, "batch_sampler") and hasattr(dataloader.batch_sampler, "set_epoch"):
567
+ dataloader.batch_sampler.set_epoch(epoch)
568
+ elif hasattr(batch_sampler, "set_epoch"):
569
+ batch_sampler.set_epoch(epoch)
570
 
 
 
 
 
571
  accelerator.wait_for_everyone()
572
  unet.train()
573
+
574
+ batch_losses = []
575
+ batch_grads = []
576
+
577
  for step, (latents, embeddings) in enumerate(dataloader):
578
  with accelerator.accumulate(unet):
579
+ if save_model == False and step == 5:
580
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
581
  print(f"Шаг {step}: {used_gb:.2f} GB")
582
 
 
590
 
591
  mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
592
 
593
+ # Сохраняем для локальных логов
594
+ batch_losses.append(mse_loss.detach().cpu().item())
595
 
 
 
 
596
  # Backward
597
  accelerator.backward(mse_loss)
598
 
599
+ grad_norm_val = 0.0
 
 
 
600
  if not fbp:
601
  if accelerator.sync_gradients:
602
+ # Clip gradients and step only when gradients are synchronized (i.e. actual optimizer step)
603
+ grad_norm = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
604
+ grad_norm_val = float(grad_norm)
605
  optimizer.step()
606
  lr_scheduler.step()
607
  optimizer.zero_grad(set_to_none=True)
608
 
609
+ # increment global_step only when we have synchronized gradients (i.e. on optimizer step)
610
+ # FIXED: ensure global_step reflects optimizer updates, not micro-batches
611
+ if accelerator.sync_gradients:
612
+ global_step += 1
613
+ progress_bar.update(1)
614
+
615
+ # Логируем метрики только на главном процессе
616
+ if accelerator.is_main_process:
617
+ if fbp:
618
+ current_lr = base_learning_rate
619
+ else:
620
+ current_lr = lr_scheduler.get_last_lr()[0]
621
+
622
+ batch_grads.append(grad_norm_val)
623
+
624
+ log_data = {
625
+ "loss": mse_loss.detach().item(),
626
+ "lr": current_lr,
627
+ "grad": grad_norm_val,
628
+ "epoch": epoch + 1,
629
+ "global_step": global_step,
630
+ }
631
 
 
 
 
 
 
632
  if use_wandb:
633
  wandb.log(log_data, step=global_step)
634
  if use_comet_ml:
635
  comet_experiment.log_metrics(log_data, step=global_step)
636
 
637
+ # Генерируем сэмплы с заданным интервалом (только на главном процессе)
638
+ if global_step % sample_interval == 0:
639
+ generate_and_save_samples(fixed_samples, global_step)
 
 
 
 
 
 
 
 
 
 
 
 
640
 
641
+ if save_model:
642
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
643
+ if use_wandb:
644
+ wandb.log({"sample_loss": avg_sample_loss})
645
+ print("saving:", avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
646
+ if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
647
+ min_loss = avg_sample_loss
648
+ save_checkpoint(unet)
649
 
650
+ # Эпоха окончена — агрегируем и логируем средние значения
651
  if accelerator.is_main_process:
652
+ # local averages
653
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
654
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
655
+
656
  print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
657
+ log_data_ep = {
658
+ "epoch_loss": avg_epoch_loss,
659
+ "epoch_grad": avg_epoch_grad,
660
+ "epoch": epoch + 1,
661
+ }
662
  if use_wandb:
663
+ wandb.log(log_data_ep)
664
+ if use_comet_ml:
665
+ comet_experiment.log_metrics(log_data_ep)
666
 
667
  # Завершение обучения - сохраняем финальную модель
668
  if accelerator.is_main_process:
669
  print("Обучение завершено! Сохраняем финальную модель...")
670
  if save_model:
671
+ save_checkpoint(unet, "fp16")
672
  if use_comet_ml:
673
  comet_experiment.end()
674
  accelerator.free_memory()
{unet37ep → unet}/config.json RENAMED
File without changes
{unet37ep → unet}/diffusion_pytorch_model.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e6fbf2b837fc99c48fa0815289cf1babe32a45dc7b975c4d0fcdcf41a7547367
3
  size 6184944280
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad73d05c1584296954b3b137600b58f87c376f31cb5a80211a0dad2e3e118e9e
3
  size 6184944280
vae/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.35.1",
4
+ "_name_or_path": "AiArtLab/simplevae",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": false,
19
+ "in_channels": 3,
20
+ "latent_channels": 16,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 1024,
28
+ "scaling_factor": 1.0,
29
+ "shift_factor": 0,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ff1aa1cce9dcf003c7eb23548fdb62101fe9bdeedf30aa4a897b91a6ef13f93
3
+ size 167669678