527

Files changed:
- samples/unet_192x384_0.jpg +2 -2
- samples/unet_256x384_0.jpg +2 -2
- samples/unet_320x384_0.jpg +2 -2
- samples/unet_384x192_0.jpg +2 -2
- samples/unet_384x256_0.jpg +2 -2
- samples/unet_384x320_0.jpg +2 -2
- samples/unet_384x384_0.jpg +2 -2
- train.py +44 -70
- unet/diffusion_pytorch_model.safetensors +1 -1
samples/unet_192x384_0.jpg
CHANGED (Git LFS image)

samples/unet_256x384_0.jpg
CHANGED (Git LFS image)

samples/unet_320x384_0.jpg
CHANGED (Git LFS image)

samples/unet_384x192_0.jpg
CHANGED (Git LFS image)

samples/unet_384x256_0.jpg
CHANGED (Git LFS image)

samples/unet_384x320_0.jpg
CHANGED (Git LFS image)

samples/unet_384x384_0.jpg
CHANGED (Git LFS image)
train.py
CHANGED

@@ -37,7 +37,7 @@ use_wandb = True
 save_model = True
 use_decay = True
 fbp = False # fused backward pass
-
+optimizer_type = "adam8bit"
 torch_compile = False
 unet_gradient = True
 clip_sample = False #Scheduler

@@ -187,7 +187,7 @@ if use_wandb and accelerator.is_main_process:
     "base_learning_rate": base_learning_rate,
     "num_epochs": num_epochs,
     "fbp": fbp,
-    "
+    "optimizer_type": optimizer_type,
 })
 
 # Enable Flash Attention 2/SDPA
@@ -495,89 +495,63 @@ else:
 if fbp:
     trainable_params = list(unet.parameters())
 
+def create_optimizer(name, params):
+    if name == "adam8bit":
+        import bitsandbytes as bnb
+        return bnb.optim.AdamW8bit(
+            params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
+        )
+    elif name == "adam":
+        return torch.optim.AdamW(
+            params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
+        )
+    elif name == "lion8bit":
+        import bitsandbytes as bnb
+        return bnb.optim.Lion8bit(
+            params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
+            percentile_clipping=percentile_clipping
+        )
+    elif name == "adafactor":
+        from transformers import Adafactor
+        return Adafactor(
+            params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
+            warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
+            beta1=0.9, weight_decay=0.01
+        )
     else:
-            weight_decay=0.01,
-            percentile_clipping=percentile_clipping,
-        ) for p in trainable_params
-    }
+        raise ValueError(f"Unknown optimizer: {name}")
+
+if fbp:
+    # Create a separate optimizer for each parameter
+    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
 
-    # [2] Define a hook that applies the optimizer as soon as a gradient is accumulated
     def optimizer_hook(param):
         optimizer_dict[param].step()
         optimizer_dict[param].zero_grad(set_to_none=True)
 
-    # [3] Register the hook on the model's trainable parameters
     for param in trainable_params:
         param.register_post_accumulate_grad_hook(optimizer_hook)
 
-    # Preparation via Accelerator
     unet, optimizer = accelerator.prepare(unet, optimizer_dict)
 else:
-            weight_decay=0.01
-        )
-    #from torch.optim import AdamW
-    #optimizer = AdamW(
-    #    params=unet.parameters(),
-    #    lr=base_learning_rate,
-    #    betas=(0.9, 0.999),
-    #    eps=1e-8,
-    #    weight_decay=0.01
-    #)
-    else:
-        optimizer = bnb.optim.Lion8bit(
-            params=unet.parameters(),
-            lr=base_learning_rate,
-            betas=(0.9, 0.97),
-            weight_decay=0.01,
-            percentile_clipping=percentile_clipping,
-        )
-
-    def lr_schedule(step, max_steps, base_lr, min_lr, use_decay=True):
-        # If decay is not used, return the base LR
+    optimizer = create_optimizer(optimizer_type, unet.parameters())
+
+    def lr_schedule(step):
+        x = step / (total_training_steps * world_size)
+        warmup = 0.05
+
         if not use_decay:
-            return
+            return base_learning_rate
+        if x < warmup:
+            return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
 
-        if x < percent:
-            # Linear warmup over the first percent of steps
-            return min_lr + (base_lr - min_lr) * (x / percent)
-        else:
-            # Cosine decay
-            decay_ratio = (x - percent) / (1 - percent)
-            return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
+        decay_ratio = (x - warmup) / (1 - warmup)
+        return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
+            (1 + math.cos(math.pi * decay_ratio))
 
-    def custom_lr_lambda(step):
-        return lr_schedule(step, total_training_steps*world_size,
-                           base_learning_rate, min_learning_rate,
-                           use_decay) / base_learning_rate
-
-    lr_scheduler = LambdaLR(optimizer, lr_lambda=custom_lr_lambda)
+    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
     unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
 
 # Register hooks AFTER prepare
 dispersive_hook.register_hooks(unet, "down_blocks.2")
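
Note on the fused-backward-pass branch above: it relies on PyTorch's per-tensor register_post_accumulate_grad_hook (available since PyTorch 2.1), stepping a dedicated optimizer for each parameter the moment its gradient is ready, so no global optimizer.step() is needed after backward(). A minimal standalone sketch of the pattern, with a toy module and plain torch.optim.AdamW standing in for the UNet and the 8-bit optimizers used in train.py:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))

    # One optimizer per parameter (AdamW here stands in for bnb.optim.AdamW8bit).
    optimizer_dict = {
        p: torch.optim.AdamW([p], lr=1e-4, weight_decay=0.01)
        for p in model.parameters()
    }

    def optimizer_hook(param):
        # Runs as soon as the gradient for `param` has been accumulated:
        # step and free that gradient instead of keeping all grads alive
        # until a single optimizer.step() after backward().
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(optimizer_hook)

    x, y = torch.randn(8, 16), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()  # parameter updates happen inside backward() via the hooks

The main benefit is memory: gradients are consumed and released per parameter during the backward pass instead of being held for the whole model until a single optimizer step.
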
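
The rewritten lr_schedule keeps the same shape as the old one (linear warmup, then cosine decay down to the minimum LR) but hard-codes a 5% warmup fraction and reads base_learning_rate, min_learning_rate and the step budget from the enclosing scope instead of taking them as arguments; LambdaLR then needs the result divided by base_learning_rate because it expects a multiplicative factor, not an absolute LR. A self-contained sketch with illustrative values (1e-4 base, 1e-6 minimum, 10,000 steps; these numbers are placeholders, not the ones in train.py):

    import math

    base_lr, min_lr, total_steps, warmup = 1e-4, 1e-6, 10_000, 0.05

    def lr_schedule(step):
        x = step / total_steps          # progress through training, 0..1
        if x < warmup:                  # linear warmup over the first 5% of steps
            return min_lr + (base_lr - min_lr) * (x / warmup)
        decay_ratio = (x - warmup) / (1 - warmup)
        return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))

    # LambdaLR multiplies the optimizer's base lr by the lambda's return value,
    # hence the division by base_lr in train.py:
    #   lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_lr)

    print(lr_schedule(0))       # 1e-06     (start of warmup)
    print(lr_schedule(500))     # 1e-04     (end of warmup, 5% of total steps)
    print(lr_schedule(5250))    # ~5.05e-05 (halfway through the cosine decay)
    print(lr_schedule(10000))   # 1e-06     (fully decayed)
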
unet/diffusion_pytorch_model.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:7a3a5c9271f789977d7738b696c288c509253bbb416e182ab58925101a0a9409
 size 7014306128
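
What changed here is only the Git LFS pointer: the repository tracks the ~7 GB checkpoint by its version/oid/size triple, and this commit points the oid at the newly trained weights. A small Python check that a locally downloaded file matches the pointer (the path and oid below are the ones from this commit; adjust the path for your checkout):

    import hashlib

    def sha256_of(path, chunk_size=1 << 20):
        # Stream in 1 MiB chunks so the 7 GB file never has to fit in memory.
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()

    expected = "7a3a5c9271f789977d7738b696c288c509253bbb416e182ab58925101a0a9409"
    digest = sha256_of("unet/diffusion_pytorch_model.safetensors")
    print("OK" if digest == expected else f"mismatch: {digest}")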