recoilme committed
Commit 63fe015 · 1 Parent(s): 1fada1f
Sample images (Git LFS details, before and after):

samples/unet_384x768_0.jpg CHANGED
  • Before: SHA256 09c35eb121c172751009625cabeb0b49d01c0c9421b194d745089e949c423031 · Pointer size: 130 Bytes · Remote file: 88.8 kB
  • After: SHA256 db3d876660d9fc3451ffca740acebaca381badb7fce469a1c6af4c7a284a2dfb · Pointer size: 130 Bytes · Remote file: 87.8 kB

samples/unet_416x768_0.jpg CHANGED
  • Before: SHA256 1ea19ab79156f33b3d740fb4dd490c26169b29e89c22c75742750676b7897119 · Pointer size: 130 Bytes · Remote file: 94.4 kB
  • After: SHA256 bf63fd8c7f16e77eed1f42f5df0ffbcbd34dc92ee4d83b4400b2e0c0aa62f604 · Pointer size: 131 Bytes · Remote file: 111 kB

samples/unet_448x768_0.jpg CHANGED
  • Before: SHA256 1c9272c9771994cb3d576ab32eec5a12ea588909c96977811fd4bdc8caa29b4d · Pointer size: 131 Bytes · Remote file: 102 kB
  • After: SHA256 f593c68b8a4ac399d8df2cff6787f308a29c7909436aaab856b814dd7747dc26 · Pointer size: 131 Bytes · Remote file: 147 kB

samples/unet_480x768_0.jpg CHANGED
  • Before: SHA256 6d2c928c3937e05ad4bff0da9ed50b7a733f5e38854f3cc18d27cde4bd7e3f3a · Pointer size: 130 Bytes · Remote file: 89.9 kB
  • After: SHA256 5ca4ecb1134384bc249e6d9f884efa9bf5850cc44ec186eeec6e2099cba91fc2 · Pointer size: 130 Bytes · Remote file: 53.2 kB

samples/unet_512x768_0.jpg CHANGED
  • Before: SHA256 40f96083a53374119de7d3c668ca72f02d403e41458cec0c34cd1b552d2dffb0 · Pointer size: 130 Bytes · Remote file: 77.2 kB
  • After: SHA256 edf30a1644d71c8f0bfb13f80ea1dcf19745692139bcc0b17519e332232f785b · Pointer size: 131 Bytes · Remote file: 191 kB

samples/unet_544x768_0.jpg CHANGED
  • Before: SHA256 05f589a3bcc74fe6922ec49586636db68e82c6c3f297c791c2edb75ebc6db513 · Pointer size: 130 Bytes · Remote file: 89.1 kB
  • After: SHA256 febc8d9a5d218e3df446f3c455b764299e3dd87a96f6a2a3d03f5d904a6cf191 · Pointer size: 131 Bytes · Remote file: 187 kB

samples/unet_576x768_0.jpg CHANGED
  • Before: SHA256 ac9127d12bc2e45ffc0a594bd901ab2d5d016b37516012384e094cee6684c007 · Pointer size: 131 Bytes · Remote file: 158 kB
  • After: SHA256 44487823e9513d4b1d005159503d6968f8242af8afd33afcfc9f27516b6635e6 · Pointer size: 131 Bytes · Remote file: 160 kB

samples/unet_608x768_0.jpg CHANGED
  • Before: SHA256 e62c105427605cd3f8d41ce67cf55c2b4b316ba35136feda684c72d54296310d · Pointer size: 131 Bytes · Remote file: 117 kB
  • After: SHA256 5e8368f7e96134a7511fce813575bacbd70e4bc94187c81b29150bfa1eb4f924 · Pointer size: 130 Bytes · Remote file: 81.5 kB

samples/unet_640x768_0.jpg CHANGED
  • Before: SHA256 50d9114cfb516526acb3cd8a0b544d95998d6ddffdb77ce5fc6ba663f99e7df1 · Pointer size: 130 Bytes · Remote file: 96.5 kB
  • After: SHA256 6efe3ac0e579303d0c34a2a69e666d10a6b9bc4060e9d85d8abe4f74dafe56a5 · Pointer size: 131 Bytes · Remote file: 132 kB

samples/unet_672x768_0.jpg CHANGED
  • Before: SHA256 801d56debd926a84a6e12dfb29d96c496cc257bb5710476edb6b3711bb7e5bcf · Pointer size: 131 Bytes · Remote file: 125 kB
  • After: SHA256 a8e99b6d31e2a0e4504d6832a7feab237bf8d8c435dae8533702d910b5d0aafa · Pointer size: 130 Bytes · Remote file: 99.2 kB

samples/unet_704x768_0.jpg CHANGED
  • Before: SHA256 eecc6da7c3dcba4079ec9a51bf66175b5e41342929aed8e53d79adf3a70c86f0 · Pointer size: 131 Bytes · Remote file: 166 kB
  • After: SHA256 bf1f87eeeee8ba2a0689408919a83cc3cb7e55ae73b73e1791461f87c65ce559 · Pointer size: 130 Bytes · Remote file: 38.8 kB

samples/unet_736x768_0.jpg CHANGED
  • Before: SHA256 ff8c3883fc4b4cc0a06205d6279d46ae615329d3fab39a23f9bed8d6289d7b84 · Pointer size: 131 Bytes · Remote file: 170 kB
  • After: SHA256 df89ffe59fbb535e7d694abea93d9291775c5d9f739a5c3c01ab4b300acd05c5 · Pointer size: 131 Bytes · Remote file: 138 kB

samples/unet_768x384_0.jpg CHANGED
  • Before: SHA256 802b15929117b53d258d08fbbe03630b1b2baa79466b9c81f729de77560e4f46 · Pointer size: 131 Bytes · Remote file: 136 kB
  • After: SHA256 858d032e82bc14cf41f32c2f0d47b47b72a26034a84b312c6b667cb561b0f082 · Pointer size: 130 Bytes · Remote file: 95.9 kB

samples/unet_768x416_0.jpg CHANGED
  • Before: SHA256 6c3418a13e1d34e1d176b8a518c1ce48763025a15b93b0cc173e87745d31d70f · Pointer size: 131 Bytes · Remote file: 114 kB
  • After: SHA256 ff9dc43b30cd6745e85ff138a1e7d95a759ac6de7e3b253bdbe7cce9fb66f830 · Pointer size: 131 Bytes · Remote file: 127 kB

samples/unet_768x448_0.jpg CHANGED
  • Before: SHA256 5ec2d51c7581356a49b384d2f473c70d72678176a96c59d92a4d72a11df51be5 · Pointer size: 131 Bytes · Remote file: 101 kB
  • After: SHA256 1ef88a51e5bc2a735092cf3d93fe26591f04ef8a0b11f4ef961dfdabe4e55957 · Pointer size: 130 Bytes · Remote file: 61.7 kB

samples/unet_768x480_0.jpg CHANGED
  • Before: SHA256 5a424c54aa00f9e79cc060e5609b57cd1ceb30758f93f21113c2f5f4d20091c3 · Pointer size: 131 Bytes · Remote file: 310 kB
  • After: SHA256 b4723f5784d2ea32a69b7a0a64a8dc8b2341af01dbbf117cf48a13ba459bd549 · Pointer size: 131 Bytes · Remote file: 107 kB

samples/unet_768x512_0.jpg CHANGED
  • Before: SHA256 c6ce8ab3293f4d4e8b829f988f8f648d3fd68fcb72c51a4957e7b6dd4eaa3faf · Pointer size: 131 Bytes · Remote file: 194 kB
  • After: SHA256 063257c30ad38063b1c3b451ac0375400daea2b32939c51e95530c40ef6b0122 · Pointer size: 131 Bytes · Remote file: 128 kB

samples/unet_768x544_0.jpg CHANGED
  • Before: SHA256 a524eb892d693db8064a78598902aef039d88ce89a673da812c49a4ef939aee3 · Pointer size: 130 Bytes · Remote file: 92 kB
  • After: SHA256 b711280bf5fe47a2cd291ab409ceff2c828a89a5f133eb8ffef3f107dae4301e · Pointer size: 131 Bytes · Remote file: 104 kB

samples/unet_768x576_0.jpg CHANGED
  • Before: SHA256 9698162220a01421171ce099dccf1f07e1dee7409b8f184257006f62fe93e416 · Pointer size: 130 Bytes · Remote file: 49.2 kB
  • After: SHA256 6b16571a3d34ec7728cf22d6370d9ebb500d1fd979cd97346659d0875ce3e7be · Pointer size: 130 Bytes · Remote file: 22.9 kB

samples/unet_768x608_0.jpg CHANGED
  • Before: SHA256 ea980e4896195daf66fa0a57daa058edefb48dca5df8422415a2cc1d8f34514f · Pointer size: 130 Bytes · Remote file: 76.4 kB
  • After: SHA256 d494e6c87316f3f6f43a7c3c3440a79829b3471b3809ecaf52b90d651cd55252 · Pointer size: 131 Bytes · Remote file: 209 kB

samples/unet_768x640_0.jpg CHANGED
  • Before: SHA256 0ec6daa16455405a9824ab054130d4dd73cb1b6dfc6c0195066195e856ab8d23 · Pointer size: 131 Bytes · Remote file: 110 kB
  • After: SHA256 0249224cfa790ae7ead61eb8ac51febea8171b2a3a8d8374db596b06f84d819a · Pointer size: 131 Bytes · Remote file: 203 kB

samples/unet_768x672_0.jpg CHANGED
  • Before: SHA256 ff791690eefde6c8e60bc646127bf81699c9dd9a02f99e64466f7472971a8099 · Pointer size: 130 Bytes · Remote file: 97.6 kB
  • After: SHA256 971ac94086c5b0af748b5e4de32f6c774115bae492829ab07942de5404e7570a · Pointer size: 131 Bytes · Remote file: 117 kB

samples/unet_768x704_0.jpg CHANGED
  • Before: SHA256 ad09a5f7d2a94c723eb3a9bb3b9b02b8b000463a722ce97a337a8ee10bdda3c8 · Pointer size: 130 Bytes · Remote file: 89.4 kB
  • After: SHA256 9bbc3aa774ec450b914926ba1e845fb6e77fdda2ae1e6a8ddf4a7a2ef04774e4 · Pointer size: 131 Bytes · Remote file: 180 kB

samples/unet_768x736_0.jpg CHANGED
  • Before: SHA256 d61a8a18980f1037588194da0862ab5f9b9fe739a8a520f5de5a9f14a3958aa0 · Pointer size: 130 Bytes · Remote file: 95.7 kB
  • After: SHA256 5f55edf0582707708f90e8ae5f0fa5f2c93b15625a1a9c3518066fad8c7396f3 · Pointer size: 131 Bytes · Remote file: 121 kB

samples/unet_768x768_0.jpg CHANGED
  • Before: SHA256 7a1df37cda034e531f3656602baca7675c152aa38fd81218f348aeb6814157d5 · Pointer size: 131 Bytes · Remote file: 143 kB
  • After: SHA256 2cc8525522891dd9b5b84074f6cc4fd9ab185eb9e645fc553212013f526171f2 · Pointer size: 131 Bytes · Remote file: 179 kB
snooc.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ from torch.optim import Optimizer
+
+ class SnooC(Optimizer):
+     """
+     @DominikKallusky, @vishal9-team, @vinaysrao
+
+     Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper around any
+     inner optimizer. It can improve the stability and smoothness of the optimization
+     process, and thus the quality of large language models (LLMs) and other models.
+     Snoo implicitly adds temporal regularization to the parameters, smoothing the
+     training trajectory and instilling a bias towards flatter minima and lower
+     parameter norms. Snoo is computationally efficient, incurring minimal compute
+     overhead and moderate memory usage.
+     """
+
+     @torch.no_grad()
+     def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
+         self.optimizer = optimizer
+         self.lr = lr
+         self.momentum = momentum
+         self.k = k
+         self.current_step = 0
+         self.model_params = None
+         self.outer_buf = None
+         self.outer_optimizer = None
+
+         # If the inner optimizer already has parameters, expose its param groups.
+         if self.optimizer.param_groups:
+             self.param_groups = self.optimizer.param_groups
+
+     @torch.no_grad()
+     def _initialize_outer_optimizer(self):
+         # Collect every tensor parameter from the inner optimizer's groups into
+         # one flat list (the original else-branch overwrote it; extend instead).
+         params = []
+         for pg in self.optimizer.param_groups:
+             for param in pg['params']:
+                 if isinstance(param, torch.Tensor):
+                     params.append(param)
+
+         if not params:
+             return
+
+         self.model_params = list(params)
+         # Snapshot of the parameters as of the last outer step.
+         self.outer_buf = [p.clone() for p in self.model_params]
+         self.outer_optimizer = torch.optim.SGD(
+             self.model_params,
+             lr=self.lr,
+             momentum=self.momentum,
+             nesterov=True,
+             fused=True,
+         )
+         self.param_groups = self.optimizer.param_groups
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         if self.outer_optimizer is None or self.current_step == 0:
+             # If the inner optimizer has been populated with parameters, initialize.
+             if self.optimizer.param_groups:
+                 self._initialize_outer_optimizer()
+             else:
+                 # If there are still no parameters, we cannot perform an outer step.
+                 # Depending on the use case, you might want to raise an error
+                 # or simply return without doing anything.
+                 return self.optimizer.step(closure)
+
+         loss = self.optimizer.step(closure)
+         if self.current_step % self.k == 0:
+             # The outer "gradient" is the parameter delta accumulated since the
+             # last outer step; reset the parameters to the snapshot before stepping.
+             for p_new, p_old in zip(self.model_params, self.outer_buf):
+                 p_new.grad = p_old.data - p_new.data
+                 p_new.copy_(p_old, non_blocking=True)
+
+             self.outer_optimizer.step()
+
+             # Refresh the snapshot with the post-outer-step parameters.
+             for p_new, p_old in zip(self.model_params, self.outer_buf):
+                 p_old.copy_(p_new, non_blocking=True)
+         self.current_step += 1
+         return loss
+
+     def zero_grad(self, set_to_none: bool = False):
+         self.optimizer.zero_grad(set_to_none=set_to_none)
+
+     def state_dict(self):
+         return self.optimizer.state_dict()
+
+     def load_state_dict(self, state_dict):
+         self.optimizer.load_state_dict(state_dict)
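For reference, a minimal usage sketch of the new wrapper. The toy model, data, and loop below are hypothetical; note that SnooC builds its outer SGD with fused=True, which assumes a recent PyTorch and a device with a fused SGD implementation (e.g. CUDA):

import torch
from snooc import SnooC

model = torch.nn.Linear(16, 4).cuda()                   # hypothetical toy model
inner = torch.optim.AdamW(model.parameters(), lr=1e-3)  # any inner optimizer
optimizer = SnooC(inner, lr=0.67, momentum=0.67, k=20)

for _ in range(100):
    x = torch.randn(8, 16, device="cuda")
    loss = model(x).pow(2).mean()  # dummy objective
    loss.backward()
    optimizer.step()   # inner step on every call; Nesterov outer step every k-th
    optimizer.zero_grad()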
train.py CHANGED
@@ -29,14 +29,14 @@ from transformers import AutoTokenizer, AutoModel
  # --------------------------- Parameters ---------------------------
  ds_path = "/workspace/sdxs/datasets/768"
  project = "unet"
- batch_size = 40
+ batch_size = 36
  base_learning_rate = 4e-5
  min_learning_rate = 2.7e-5
  num_epochs = 100
  sample_interval_share = 5
  max_length = 192
- use_wandb = True
- use_comet_ml = False
+ use_wandb = False
+ use_comet_ml = True
  save_model = True
  use_decay = True
  fbp = False

@@ -53,8 +53,8 @@ torch.backends.cudnn.allow_tf32 = True
  dtype = torch.float32
  save_barrier = 1.006
  warmup_percent = 0.01
- percentile_clipping = 98
- betta2 = 0.999 #0.999?
+ percentile_clipping = 97
+ betta2 = 0.999
  eps = 1e-7
  clip_grad_norm = 1.0
  limit = 0

@@ -377,68 +377,21 @@ def create_optimizer(name, params):
      return torch.optim.AdamW(
          params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
      )
- elif name == "adamuon":
-     from pytorch_optimizer import AdaMuon
-     # 1. PARAMETER SPLITTING
-     # Split the parameters into those that are Muon-compatible (ndim >= 2)
-     # and those better left to AdamW8bit (ndim < 2).
-
-     muon_params = []
-     adam8bit_params = []
-
-     if isinstance(params, (list, tuple)):
-         all_params = params
-     else:
-         # If params is a generator (unet.parameters()), convert it to a list
-         all_params = list(params)
-
-     for p in all_params:
-         if p.requires_grad:
-             if p.ndim >= 2:
-                 muon_params.append(p)
-             else:
-                 adam8bit_params.append(p)
-
-     # 2. PARAMETER GROUP CREATION
-     param_groups = []
-
-     # --- Group 1: AdaMuon (for matrices/convolutions) ---
-     if muon_params:
-         print(f"Using AdaMuon for {len(muon_params)} parameters (ndim >= 2)")
-         param_groups.append({
-             'params': muon_params,
-             'optimizer': 'AdaMuon',  # name the optimizer explicitly
-             'use_muon': True,        # required for AdaMuon
-             'lr': base_learning_rate,
-             'betas': (0.9, betta2),
-             'eps': eps,
-             'weight_decay': 0.01,
-             'nesterov': True,
-             'decouple_weight_decay': True
-         })
-
-     # --- Group 2: AdamW8bit (for biases/normalization) ---
-     if adam8bit_params:
-         print(f"Using AdamW8bit for {len(adam8bit_params)} parameters (ndim < 2)")
-         param_groups.append({
-             'params': adam8bit_params,
-             'optimizer': 'AdamW8bit',  # name the optimizer explicitly
-             'use_muon': False,         # AdamW8bit does not need use_muon
-             'lr': base_learning_rate,
-             'betas': (0.9, betta2),
-             'eps': eps,
-             'weight_decay': 0.01,
-             'percentile_clipping': percentile_clipping
-         })
-
-     if not param_groups:
-         raise ValueError("No parameters to optimize.")
-
-     # 3. INITIALIZATION (mixed optimizer)
-     # When given a list of groups, the AdaMuon optimizer automatically
-     # recognizes the 'optimizer' keys and initializes the internal optimizers.
-     return AdaMuon(param_groups)
+ elif name == "muon":
+     from muon import MuonWithAuxAdam
+     trainable_params = [p for p in params if p.requires_grad]
+     hidden_weights = [p for p in trainable_params if p.ndim >= 2]
+     hidden_gains_biases = [p for p in trainable_params if p.ndim < 2]
+
+     param_groups = [
+         dict(params=hidden_weights, use_muon=True,
+              lr=1e-3, weight_decay=1e-4),
+         dict(params=hidden_gains_biases, use_muon=False,
+              lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
+     ]
+     optimizer = MuonWithAuxAdam(param_groups)
+     from snooc import SnooC
+     return SnooC(optimizer)
  else:
      raise ValueError(f"Unknown optimizer: {name}")

@@ -723,12 +676,20 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
          # Pass the (emb, mask) tuple for the negative prompt
          generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
          last_n = sample_interval
-
+
      if save_model:
-         avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
-         print("saving:", avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
-         if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
-             min_loss = avg_sample_loss
+         has_losses = len(batch_losses) > 0
+         avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
+         last_loss = batch_losses[-1] if has_losses else 0.0
+         max_loss = max(avg_sample_loss, last_loss)
+         should_save = max_loss < min_loss * save_barrier
+         print(
+             f"Saving: {should_save} | Max: {max_loss:.4f} | "
+             f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
+         )
+         # Save and update the best loss
+         if should_save:
+             min_loss = max_loss
          save_checkpoint(unet)

      if accelerator.is_main_process:
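In short, this commit swaps the hand-rolled AdaMuon/AdamW8bit split for the muon package's MuonWithAuxAdam (weights with ndim >= 2 get Muon's orthogonalized updates, gains and biases go to the auxiliary Adam) and wraps the result in SnooC; it also tightens the checkpoint gate so that both the recent average loss and the very last loss must beat min_loss * save_barrier. A standalone sketch of the optimizer composition, under the same assumptions the diff makes (the muon package's use_muon group convention and the new snooc.py):

import torch
from muon import MuonWithAuxAdam  # same import the diff relies on
from snooc import SnooC

def build_optimizer(params, base_lr: float = 1e-3):
    trainable = [p for p in params if p.requires_grad]
    groups = [
        # ndim >= 2: weight matrices and conv kernels, updated by Muon
        dict(params=[p for p in trainable if p.ndim >= 2], use_muon=True,
             lr=base_lr, weight_decay=1e-4),
        # ndim < 2: gains and biases, updated by the auxiliary Adam
        dict(params=[p for p in trainable if p.ndim < 2], use_muon=False,
             lr=base_lr / 10, betas=(0.9, 0.95), weight_decay=1e-4),
    ]
    # SnooC adds the k-step Nesterov outer loop on top of the inner optimizer.
    return SnooC(MuonWithAuxAdam(groups))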
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:d40e06c2a67bb2729fd4dae0a554773051648c61bc3ffa00fedf9412bf40ec74
+ oid sha256:e910ef5f2cb2a67f6f4259d48422d509e2b77e51021e41b611e7329aecd84aa0
  size 7444321360