recoilme committed · Commit 2f27d93 · 1 parent: b12d34d
samples/unet_192x384_0.jpg CHANGED

Git LFS Details (old version)

  • SHA256: 1a536d4fac31aa4e394ed8ea7f04697be7b197ea51f5b7e5d258c9b39d50ebc7
  • Pointer size: 130 Bytes
  • Size of remote file: 35.6 kB

Git LFS Details (new version)

  • SHA256: fee3a9248eb72959fcdf53469b5f7e3fe1c868518a7ee5cd1eb79fa8d241da78
  • Pointer size: 130 Bytes
  • Size of remote file: 28.3 kB
samples/unet_256x384_0.jpg CHANGED

Git LFS Details (old version)

  • SHA256: 16c7f1b7450529cbed24a07c051dda854f58ab3a1df7e6fadb6615369a4021ae
  • Pointer size: 130 Bytes
  • Size of remote file: 39.9 kB

Git LFS Details (new version)

  • SHA256: 3e14a6bb9f7079308e23b1d57380b1a31c5762088266c3cc56c61d020175b206
  • Pointer size: 130 Bytes
  • Size of remote file: 47.4 kB
samples/unet_320x384_0.jpg CHANGED

Git LFS Details (old version)

  • SHA256: 4f5958a2d74a4b8a882ca49d3b3ec1df6969b2491aefd6e8d50292a2caaeb5c2
  • Pointer size: 130 Bytes
  • Size of remote file: 35.4 kB

Git LFS Details (new version)

  • SHA256: 559137ef6c264f8e69ffc0c49164ea21425893ad25eea7a09bcc5b94ae639a46
  • Pointer size: 130 Bytes
  • Size of remote file: 48.8 kB
samples/unet_384x192_0.jpg CHANGED

Git LFS Details (old version)

  • SHA256: cbdcd248eb100bb931edff25ec92b2564c271b40e8a5e91fb39f3a1b9a94f35e
  • Pointer size: 130 Bytes
  • Size of remote file: 33.8 kB

Git LFS Details (new version)

  • SHA256: bdf751a3eb5ecc15af9ac3e32ea43880a141d3763cafbe3a0a29314270b37134
  • Pointer size: 130 Bytes
  • Size of remote file: 35.6 kB
samples/unet_384x256_0.jpg CHANGED

Git LFS Details (old version)

  • SHA256: 2eda76c157e8621a2973bdde28e6a5fe961210559d117f7a2f275c773c991c80
  • Pointer size: 130 Bytes
  • Size of remote file: 46.3 kB

Git LFS Details (new version)

  • SHA256: 2e8280d0b253a93c73422f133a34955c4a39c54dceb54591f7afb9cb387ec2cc
  • Pointer size: 130 Bytes
  • Size of remote file: 52.4 kB
samples/unet_384x320_0.jpg CHANGED

Git LFS Details (old version)

  • SHA256: a8e9f4715b636d643273195ec3bb2781d6bf219461c661e36559600e61f2dffe
  • Pointer size: 130 Bytes
  • Size of remote file: 72.2 kB

Git LFS Details (new version)

  • SHA256: 07ecb57f6c3c130b6eaaf0fd4a4927ccc63a982d3a943e0e7d623260184b61c6
  • Pointer size: 130 Bytes
  • Size of remote file: 72.6 kB
samples/unet_384x384_0.jpg CHANGED

Git LFS Details (old version)

  • SHA256: 771a94814a3c2e18910c085d90404ea851f40b0468e0dc8cb19b8f6ac92e41fb
  • Pointer size: 130 Bytes
  • Size of remote file: 44.6 kB

Git LFS Details (new version)

  • SHA256: 79aa26427b527f0897b2029f6f2abc03a63d74053c6753cb1d62ffc62159b7dd
  • Pointer size: 130 Bytes
  • Size of remote file: 50.1 kB
train.py CHANGED
@@ -37,7 +37,7 @@ use_wandb = True
 save_model = True
 use_decay = True
 fbp = False # fused backward pass
-adam8bit = True
+optimizer_type = "adam8bit"
 torch_compile = False
 unet_gradient = True
 clip_sample = False #Scheduler
@@ -187,7 +187,7 @@ if use_wandb and accelerator.is_main_process:
         "base_learning_rate": base_learning_rate,
         "num_epochs": num_epochs,
         "fbp": fbp,
-        "adam8bit": adam8bit,
+        "optimizer_type": optimizer_type,
     })

 # Enable Flash Attention 2/SDPA
@@ -495,89 +495,63 @@ else:
 if fbp:
     trainable_params = list(unet.parameters())

-if fbp:
-    # [1] Build the optimizer dict (fused backward)
-    if adam8bit:
-        optimizer_dict = {
-            p: bnb.optim.AdamW8bit(
-                [p],  # each parameter gets its own optimizer
-                lr=base_learning_rate,
-                eps=1e-8
-            ) for p in trainable_params
-        }
+def create_optimizer(name, params):
+    if name == "adam8bit":
+        import bitsandbytes as bnb
+        return bnb.optim.AdamW8bit(
+            params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
+        )
+    elif name == "adam":
+        return torch.optim.AdamW(
+            params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
+        )
+    elif name == "lion8bit":
+        import bitsandbytes as bnb
+        return bnb.optim.Lion8bit(
+            params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
+            percentile_clipping=percentile_clipping
+        )
+    elif name == "adafactor":
+        from transformers import Adafactor
+        return Adafactor(
+            params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
+            warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
+            beta1=0.9, weight_decay=0.01
+        )
     else:
-        optimizer_dict = {
-            p: bnb.optim.Lion8bit(
-                [p],  # each parameter gets its own optimizer
-                lr=base_learning_rate,
-                betas=(0.9, 0.97),
-                weight_decay=0.01,
-                percentile_clipping=percentile_clipping,
-            ) for p in trainable_params
-        }
+        raise ValueError(f"Unknown optimizer: {name}")
+
+if fbp:
+    # Create a separate optimizer for each parameter
+    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}

-    # [2] Define a hook that applies the optimizer right after a gradient is accumulated
     def optimizer_hook(param):
         optimizer_dict[param].step()
         optimizer_dict[param].zero_grad(set_to_none=True)

-    # [3] Register the hook on the model's trainable parameters
     for param in trainable_params:
         param.register_post_accumulate_grad_hook(optimizer_hook)

-    # Prepare via Accelerator
     unet, optimizer = accelerator.prepare(unet, optimizer_dict)
 else:
-    if adam8bit:
-        optimizer = bnb.optim.AdamW8bit(
-            params=unet.parameters(),
-            lr=base_learning_rate,
-            betas=(0.9, 0.999),
-            eps=1e-8,
-            weight_decay=0.01
-        )
-        #from torch.optim import AdamW
-        #optimizer = AdamW(
-        #    params=unet.parameters(),
-        #    lr=base_learning_rate,
-        #    betas=(0.9, 0.999),
-        #    eps=1e-8,
-        #    weight_decay=0.01
-        #)
-    else:
-        optimizer = bnb.optim.Lion8bit(
-            params=unet.parameters(),
-            lr=base_learning_rate,
-            betas=(0.9, 0.97),
-            weight_decay=0.01,
-            percentile_clipping=percentile_clipping,
-        )
-
-    def lr_schedule(step, max_steps, base_lr, min_lr, use_decay=True):
-        # If decay is disabled, just return the base LR
+    optimizer = create_optimizer(optimizer_type, unet.parameters())
+
+    def lr_schedule(step):
+        x = step / (total_training_steps * world_size)
+        warmup = 0.05
+
         if not use_decay:
-            return base_lr
+            return base_learning_rate
+        if x < warmup:
+            return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)

-        # Otherwise: linear warmup followed by cosine decay
-        x = step / max_steps
-        percent = 0.05
-        if x < percent:
-            # Linear warmup over the first `percent` of steps
-            return min_lr + (base_lr - min_lr) * (x / percent)
-        else:
-            # Cosine decay
-            decay_ratio = (x - percent) / (1 - percent)
-            return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
+        decay_ratio = (x - warmup) / (1 - warmup)
+        return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
+            (1 + math.cos(math.pi * decay_ratio))

-
-    def custom_lr_lambda(step):
-        return lr_schedule(step, total_training_steps*world_size,
-                           base_learning_rate, min_learning_rate,
-                           use_decay) / base_learning_rate
-
-    lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
+    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
     unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
-
+
 # Register hooks AFTER prepare
 dispersive_hook.register_hooks(unet, "down_blocks.2")
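For context, the fbp branch above steps each parameter from a PyTorch post-accumulate-grad hook instead of a global optimizer step, while the non-fbp branch drives a single optimizer through a warmup/cosine LambdaLR. Below is a minimal, self-contained sketch of both patterns; the toy Linear model, data, and step counts are placeholders and not part of train.py, and it assumes PyTorch >= 2.1 for register_post_accumulate_grad_hook.

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(16, 4)                      # stand-in for the UNet (hypothetical)
data, target = torch.randn(8, 16), torch.randn(8, 4)
base_lr, min_lr, total_steps = 1e-3, 1e-5, 100      # placeholder hyperparameters

# Fused backward pass (fbp=True): one optimizer per parameter, stepped from a hook
# that fires as soon as that parameter's gradient has been accumulated, so the
# training loop needs no global optimizer.step() / zero_grad().
optimizer_dict = {p: torch.optim.AdamW([p], lr=base_lr) for p in model.parameters()}

def optimizer_hook(param):
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad(set_to_none=True)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

loss = torch.nn.functional.mse_loss(model(data), target)
loss.backward()                                      # parameters are updated during this call

# Non-fused branch (fbp=False): a single optimizer plus LambdaLR; the lambda returns
# lr_schedule(step) / base_lr, so warmup and cosine decay scale the optimizer's base LR.
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)

def lr_schedule(step, warmup=0.05):
    x = step / total_steps
    if x < warmup:                                   # linear warmup over the first 5% of steps
        return min_lr + (base_lr - min_lr) * (x / warmup)
    decay_ratio = (x - warmup) / (1 - warmup)        # cosine decay afterwards
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))

lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_lr)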
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e00ac22a38007a5fadae89d5d39b7286f8d1189c3ba4806653a29214dfd04e3f
+oid sha256:7a3a5c9271f789977d7738b696c288c509253bbb416e182ab58925101a0a9409
 size 7014306128
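In a Git LFS pointer like the one above, the oid is simply the SHA-256 of the file contents and size is its length in bytes. A minimal sketch (not part of this repo) for verifying a local download against the new pointer; the path is assumed to be the repo-relative one:

import hashlib
from pathlib import Path

path = Path("unet/diffusion_pytorch_model.safetensors")   # assumed local download path
sha = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):       # stream in 1 MiB chunks
        sha.update(chunk)

# After this commit the values should match the pointer:
# oid  = 7a3a5c9271f789977d7738b696c288c509253bbb416e182ab58925101a0a9409
# size = 7014306128
print(sha.hexdigest(), path.stat().st_size)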