2512
- samples/unet_384x768_0.jpg +2 -2
- samples/unet_416x768_0.jpg +2 -2
- samples/unet_448x768_0.jpg +2 -2
- samples/unet_480x768_0.jpg +2 -2
- samples/unet_512x768_0.jpg +2 -2
- samples/unet_544x768_0.jpg +2 -2
- samples/unet_576x768_0.jpg +2 -2
- samples/unet_608x768_0.jpg +2 -2
- samples/unet_640x768_0.jpg +2 -2
- samples/unet_672x768_0.jpg +2 -2
- samples/unet_704x768_0.jpg +2 -2
- samples/unet_736x768_0.jpg +2 -2
- samples/unet_768x384_0.jpg +2 -2
- samples/unet_768x416_0.jpg +2 -2
- samples/unet_768x448_0.jpg +2 -2
- samples/unet_768x480_0.jpg +2 -2
- samples/unet_768x512_0.jpg +2 -2
- samples/unet_768x544_0.jpg +2 -2
- samples/unet_768x576_0.jpg +2 -2
- samples/unet_768x608_0.jpg +2 -2
- samples/unet_768x640_0.jpg +2 -2
- samples/unet_768x672_0.jpg +2 -2
- samples/unet_768x704_0.jpg +2 -2
- samples/unet_768x736_0.jpg +2 -2
- samples/unet_768x768_0.jpg +2 -2
- snooc.py +89 -0
- train.py +33 -72
- unet/diffusion_pytorch_model.safetensors +1 -1
samples/unet_*_0.jpg (25 files)
CHANGED (binary Git LFS images; before/after previews omitted)
snooc.py
ADDED
@@ -0,0 +1,89 @@
+import torch
+from torch.optim import Optimizer
+
+class SnooC(Optimizer):
+    """
+    @DominikKallusky, @vishal9-team, @vinaysrao
+
+    Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper around any
+    optimizer that can improve the stability and smoothness of the optimization
+    process, and thus the quality of large language models (LLMs) and other models.
+    Snoo implicitly adds temporal regularization to the parameters, smoothing the
+    training trajectory and instilling a bias towards flatter minima and lower
+    parameter norms. Snoo is computationally efficient, incurring minimal compute
+    overhead and moderate memory usage.
+    """
+
+    @torch.no_grad()
+    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
+        self.optimizer = optimizer
+        self.lr = lr
+        self.momentum = momentum
+        self.k = k
+        self.current_step = 0
+        self.model_params = None
+        self.outer_buf = None
+        self.outer_optimizer = None
+
+        # Check if the optimizer already has parameters.
+        if self.optimizer.param_groups:
+            self.param_groups = self.optimizer.param_groups
+
+    @torch.no_grad()
+    def _initialize_outer_optimizer(self):
+        # Collect every tensor from every parameter group, accumulating across
+        # groups so multi-group inner optimizers are handled correctly.
+        params = []
+        for pg in self.optimizer.param_groups:
+            for param in pg['params']:
+                if isinstance(param, torch.Tensor):
+                    params.append(param)
+
+        if not params:
+            return
+
+        self.model_params = params
+        self.outer_buf = [p.clone() for p in self.model_params]
+        self.outer_optimizer = torch.optim.SGD(
+            self.model_params,
+            lr=self.lr,
+            momentum=self.momentum,
+            nesterov=True,
+            fused=True,
+        )
+        self.param_groups = self.optimizer.param_groups
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        if self.outer_optimizer is None or self.current_step == 0:
+            # If the optimizer has been updated with parameters, initialize.
+            if self.optimizer.param_groups:
+                self._initialize_outer_optimizer()
+            else:
+                # If there are still no parameters, we cannot perform a step.
+                # Depending on the use case, you might want to raise an error
+                # or simply return without doing anything.
+                return self.optimizer.step(closure)
+
+        loss = self.optimizer.step(closure)
+        if self.current_step % self.k == 0:
+            # Outer pseudo-gradient: the distance the inner optimizer has
+            # travelled since the last synchronization; then restore the
+            # synchronized weights before the outer step.
+            for p_new, p_old in zip(self.model_params, self.outer_buf):
+                p_new.grad = p_old.data - p_new.data
+                p_new.copy_(p_old, non_blocking=True)
+
+            self.outer_optimizer.step()
+
+            # Refresh the synchronization buffer with the post-step weights.
+            for p_new, p_old in zip(self.model_params, self.outer_buf):
+                p_old.copy_(p_new, non_blocking=True)
+        self.current_step += 1
+        return loss
+
+    def zero_grad(self, set_to_none: bool = False):
+        self.optimizer.zero_grad(set_to_none=set_to_none)
+
+    def state_dict(self):
+        # Only the inner optimizer's state is serialized; the outer buffer and
+        # outer momentum are re-created on the first step after loading.
+        return self.optimizer.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self.optimizer.load_state_dict(state_dict)
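For orientation, a minimal usage sketch (not part of the commit): every k-th call to step() treats the distance the inner optimizer has travelled, p_old - p_new, as a pseudo-gradient for an outer Nesterov SGD step, then re-synchronizes the buffer. The model and loop below are placeholders, and since the outer SGD is built with fused=True, CUDA parameters are assumed.

import torch
from snooc import SnooC

model = torch.nn.Linear(16, 16).cuda()             # placeholder model; fused SGD expects CUDA
inner = torch.optim.AdamW(model.parameters(), lr=1e-4)
opt = SnooC(inner, lr=0.67, momentum=0.67, k=20)   # outer step every 20 inner steps

for _ in range(100):                               # placeholder training loop
    loss = model(torch.randn(8, 16, device="cuda")).pow(2).mean()
    loss.backward()
    opt.step()        # inner step; every k-th call also applies the outer step
    opt.zero_grad()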
train.py
CHANGED
@@ -29,14 +29,14 @@ from transformers import AutoTokenizer, AutoModel
 # --------------------------- Parameters ---------------------------
 ds_path = "/workspace/sdxs/datasets/768"
 project = "unet"
-batch_size =
+batch_size = 36
 base_learning_rate = 4e-5
 min_learning_rate = 2.7e-5
 num_epochs = 100
 sample_interval_share = 5
 max_length = 192
-use_wandb =
-use_comet_ml =
+use_wandb = False
+use_comet_ml = True
 save_model = True
 use_decay = True
 fbp = False

@@ -53,8 +53,8 @@ torch.backends.cudnn.allow_tf32 = True
 dtype = torch.float32
 save_barrier = 1.006
 warmup_percent = 0.01
-percentile_clipping =
-betta2 = 0.999
+percentile_clipping = 97
+betta2 = 0.999
 eps = 1e-7
 clip_grad_norm = 1.0
 limit = 0

@@ -377,68 +377,21 @@ def create_optimizer(name, params):
         return torch.optim.AdamW(
             params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
         )
-    elif name == "
-        from
-        ...
-        for p in all_params:
-            if p.requires_grad:
-                if p.ndim >= 2:
-                    muon_params.append(p)
-                else:
-                    adam8bit_params.append(p)
-
-        # 2. CREATE THE PARAMETER GROUPS
-        param_groups = []
-
-        # --- Group 1: AdaMuon (for matrices/convolutions) ---
-        if muon_params:
-            print(f"Using AdaMuon for {len(muon_params)} parameters (ndim >= 2)")
-            param_groups.append({
-                'params': muon_params,
-                'optimizer': 'AdaMuon',    # explicitly select the optimizer
-                'use_muon': True,          # required for AdaMuon
-                'lr': base_learning_rate,  # uses regular spaces
-                'betas': (0.9, betta2),
-                'eps': eps,
-                'weight_decay': 0.01,
-                'nesterov': True,
-                'decouple_weight_decay': True
-            })
-
-        # --- Group 2: AdamW8bit (for biases/normalization) ---
-        if adam8bit_params:
-            print(f"Using AdamW8bit for {len(adam8bit_params)} parameters (ndim < 2)")
-            param_groups.append({
-                'params': adam8bit_params,
-                'optimizer': 'AdamW8bit',  # explicitly select the optimizer
-                'use_muon': False,         # AdamW8bit does not require use_muon
-                'lr': base_learning_rate,  # uses regular spaces
-                'betas': (0.9, betta2),
-                'eps': eps,
-                'weight_decay': 0.01,
-                'percentile_clipping': percentile_clipping
-            })
-
-        if not param_groups:
-            raise ValueError("No parameters to optimize.")
-
-        # 3. INITIALIZATION (mixed optimizer)
-        # When passed a list of groups, the AdaMuon optimizer automatically
-        # recognizes the 'optimizer' keys and initializes the inner optimizers.
-        return AdaMuon(param_groups)
+    elif name == "muon":
+        from muon import MuonWithAuxAdam
+        trainable_params = [p for p in params if p.requires_grad]
+        hidden_weights = [p for p in trainable_params if p.ndim >= 2]
+        hidden_gains_biases = [p for p in trainable_params if p.ndim < 2]
+
+        param_groups = [
+            dict(params=hidden_weights, use_muon=True,
+                 lr=1e-3, weight_decay=1e-4),
+            dict(params=hidden_gains_biases, use_muon=False,
+                 lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
+        ]
+        optimizer = MuonWithAuxAdam(param_groups)
+        from snooc import SnooC
+        return SnooC(optimizer)
     else:
         raise ValueError(f"Unknown optimizer: {name}")

@@ -723,12 +676,20 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
         # Pass the (emb, mask) tuple for the negative prompt
         generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
         last_n = sample_interval
-
+
         if save_model:
-            ...
+            has_losses = len(batch_losses) > 0
+            avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
+            last_loss = batch_losses[-1] if has_losses else 0.0
+            max_loss = max(avg_sample_loss, last_loss)
+            should_save = max_loss < min_loss * save_barrier
+            print(
+                f"Saving: {should_save} | Max: {max_loss:.4f} | "
+                f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
+            )
+            # 6. Save and update
+            if should_save:
+                min_loss = max_loss
                 save_checkpoint(unet)
 
     if accelerator.is_main_process:
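To make the new checkpoint gate concrete, here is a small self-contained sketch with placeholder numbers (min_loss, batch_losses, and sample_interval stand in for the script's running state):

import numpy as np

save_barrier = 1.006
min_loss = 0.2500                       # best gate value so far
batch_losses = [0.260, 0.252, 0.249]    # placeholder loss history
sample_interval = 2

avg_sample_loss = np.mean(batch_losses[-sample_interval:])  # 0.2505
last_loss = batch_losses[-1]                                # 0.249
max_loss = max(avg_sample_loss, last_loss)                  # 0.2505
should_save = max_loss < min_loss * save_barrier            # 0.2505 < 0.2515 -> True
if should_save:
    min_loss = max_loss  # the gate ratchets down; later saves must stay within 0.6% of this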
unet/diffusion_pytorch_model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:e910ef5f2cb2a67f6f4259d48422d509e2b77e51021e41b611e7329aecd84aa0
 size 7444321360