2512
Browse files- dataset.py +9 -2
- media/result_grid.jpg +2 -2
- samples/unet_384x768_0.jpg +2 -2
- samples/unet_416x768_0.jpg +2 -2
- samples/unet_448x768_0.jpg +2 -2
- samples/unet_480x768_0.jpg +2 -2
- samples/unet_512x768_0.jpg +2 -2
- samples/unet_544x768_0.jpg +2 -2
- samples/unet_576x768_0.jpg +2 -2
- samples/unet_608x768_0.jpg +2 -2
- samples/unet_640x768_0.jpg +2 -2
- samples/unet_672x768_0.jpg +2 -2
- samples/unet_704x768_0.jpg +2 -2
- samples/unet_736x768_0.jpg +2 -2
- samples/unet_768x384_0.jpg +2 -2
- samples/unet_768x416_0.jpg +2 -2
- samples/unet_768x448_0.jpg +2 -2
- samples/unet_768x480_0.jpg +2 -2
- samples/unet_768x512_0.jpg +2 -2
- samples/unet_768x544_0.jpg +2 -2
- samples/unet_768x576_0.jpg +2 -2
- samples/unet_768x608_0.jpg +2 -2
- samples/unet_768x640_0.jpg +2 -2
- samples/unet_768x672_0.jpg +2 -2
- samples/unet_768x704_0.jpg +2 -2
- samples/unet_768x736_0.jpg +2 -2
- samples/unet_768x768_0.jpg +2 -2
- test.ipynb +2 -2
- train.py +98 -47
- unet/config.json +2 -2
- unet/diffusion_pytorch_model.safetensors +2 -2
dataset.py
CHANGED
|
@@ -27,8 +27,8 @@ step = 32 #64
|
|
| 27 |
empty_share = 0.0
|
| 28 |
limit = 0
|
| 29 |
# Основная процедура обработки
|
| 30 |
-
folder_path = "/workspace/
|
| 31 |
-
save_path = "/workspace/
|
| 32 |
os.makedirs(save_path, exist_ok=True)
|
| 33 |
|
| 34 |
# Функция для очистки CUDA памяти
|
|
@@ -394,6 +394,13 @@ print(f"Всего найдено {len(image_paths)} изображений")
|
|
| 394 |
# Обработка с чанкованием
|
| 395 |
process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size)
|
| 396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
# Объединение чанков в финальный датасет
|
| 398 |
combine_chunks(temp_path, save_path)
|
| 399 |
|
|
|
|
| 27 |
empty_share = 0.0
|
| 28 |
limit = 0
|
| 29 |
# Основная процедура обработки
|
| 30 |
+
folder_path = "/workspace/alchemist" #alchemist"
|
| 31 |
+
save_path = "/workspace/sdxs/datasets/alchemist" #"alchemist"
|
| 32 |
os.makedirs(save_path, exist_ok=True)
|
| 33 |
|
| 34 |
# Функция для очистки CUDA памяти
|
|
|
|
| 394 |
# Обработка с чанкованием
|
| 395 |
process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size)
|
| 396 |
|
| 397 |
+
# Удаление папки
|
| 398 |
+
try:
|
| 399 |
+
shutil.rmtree(folder_path)
|
| 400 |
+
print(f"✅ Папка {folder_path} успешно удалена")
|
| 401 |
+
except Exception as e:
|
| 402 |
+
print(f"⚠️ Ошибка при удалении папки: {e}")
|
| 403 |
+
|
| 404 |
# Объединение чанков в финальный датасет
|
| 405 |
combine_chunks(temp_path, save_path)
|
| 406 |
|
media/result_grid.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_384x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_416x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_448x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_480x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_512x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_544x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_576x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_608x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_640x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_672x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_704x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_736x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x384_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x416_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x448_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x480_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x512_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x544_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x576_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x608_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x672_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x704_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x736_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_768x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
test.ipynb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fb11577f5bc95ddaec1ac5dfbd4090581c6ea31c5c5e0bbc1d616d5948ec98aa
|
| 3 |
+
size 9700840
|
train.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from comet_ml import Experiment
|
| 2 |
import os
|
| 3 |
import math
|
| 4 |
import torch
|
|
@@ -29,14 +29,14 @@ from transformers import AutoTokenizer, AutoModel
|
|
| 29 |
# --------------------------- Параметры ---------------------------
|
| 30 |
ds_path = "/workspace/sdxs/datasets/768"
|
| 31 |
project = "unet"
|
| 32 |
-
batch_size =
|
| 33 |
-
base_learning_rate =
|
| 34 |
-
min_learning_rate = 2.
|
| 35 |
-
num_epochs =
|
| 36 |
-
sample_interval_share =
|
| 37 |
max_length = 192
|
| 38 |
-
use_wandb =
|
| 39 |
-
use_comet_ml =
|
| 40 |
save_model = True
|
| 41 |
use_decay = True
|
| 42 |
fbp = False
|
|
@@ -51,11 +51,11 @@ torch.backends.cuda.matmul.allow_tf32 = True
|
|
| 51 |
torch.backends.cudnn.allow_tf32 = True
|
| 52 |
#torch.backends.cuda.enable_mem_efficient_sdp(False)
|
| 53 |
dtype = torch.float32
|
| 54 |
-
save_barrier = 1.
|
| 55 |
warmup_percent = 0.01
|
| 56 |
percentile_clipping = 98
|
| 57 |
betta2 = 0.998
|
| 58 |
-
eps = 1e-
|
| 59 |
clip_grad_norm = 1.0
|
| 60 |
limit = 0
|
| 61 |
checkpoints_folder = ""
|
|
@@ -171,7 +171,7 @@ def encode_texts(texts, max_length=max_length):
|
|
| 171 |
batch_size = hidden.shape[0]
|
| 172 |
pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
|
| 173 |
|
| 174 |
-
#return hidden, attention_mask
|
| 175 |
# --- НОВАЯ ЛОГИКА: ОБЪЕДИНЕНИЕ ДЛЯ КРОСС-ВНИМАНИЯ ---
|
| 176 |
# 1. Расширяем пулинг-вектор до последовательности [B, 1, 1024]
|
| 177 |
pooled_expanded = pooled.unsqueeze(1)
|
|
@@ -186,7 +186,7 @@ def encode_texts(texts, max_length=max_length):
|
|
| 186 |
# torch.ones((batch_size, 1), device=device) создает маску [B, 1] со значениями 1.
|
| 187 |
new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
|
| 188 |
|
| 189 |
-
return new_encoder_hidden_states, new_attention_mask
|
| 190 |
|
| 191 |
shift_factor = getattr(vae.config, "shift_factor", 0.0)
|
| 192 |
if shift_factor is None: shift_factor = 0.0
|
|
@@ -287,9 +287,9 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
|
|
| 287 |
texts = [item["text"] for item in samples_data]
|
| 288 |
|
| 289 |
# Кодируем тексты на лету, чтобы получить маски и пулинг
|
| 290 |
-
embeddings, masks
|
| 291 |
|
| 292 |
-
fixed_samples[size] = (latents, embeddings, masks,
|
| 293 |
|
| 294 |
print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
|
| 295 |
return fixed_samples
|
|
@@ -309,19 +309,19 @@ def collate_fn_simple(batch):
|
|
| 309 |
texts = [
|
| 310 |
"" if t.lower().startswith("zero")
|
| 311 |
else "" if random.random() < 0.05
|
| 312 |
-
else t[1:].lstrip() if t.startswith(".")
|
| 313 |
-
else t
|
| 314 |
for t in raw_texts
|
| 315 |
]
|
| 316 |
|
| 317 |
# 3. Кодируем на лету
|
| 318 |
-
# Возвращает: hidden (B, L, D), mask (B, L)
|
| 319 |
-
embeddings, attention_mask
|
| 320 |
|
| 321 |
# attention_mask от токенизатора уже имеет нужный формат, но на всякий случай приведем к long
|
| 322 |
attention_mask = attention_mask.to(dtype=torch.int64)
|
| 323 |
|
| 324 |
-
return latents, embeddings, attention_mask
|
| 325 |
|
| 326 |
batch_sampler = DistributedResolutionBatchSampler(
|
| 327 |
dataset=dataset,
|
|
@@ -375,8 +375,70 @@ def create_optimizer(name, params):
|
|
| 375 |
)
|
| 376 |
elif name == "adam":
|
| 377 |
return torch.optim.AdamW(
|
| 378 |
-
params, lr=base_learning_rate, betas=(0.9,
|
| 379 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
else:
|
| 381 |
raise ValueError(f"Unknown optimizer: {name}")
|
| 382 |
|
|
@@ -418,24 +480,22 @@ def get_negative_embedding(neg_prompt="", batch_size=1):
|
|
| 418 |
seq_len = max_length
|
| 419 |
empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
|
| 420 |
empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
|
| 421 |
-
|
| 422 |
-
return empty_emb, empty_mask, empty_pool
|
| 423 |
|
| 424 |
-
uncond_emb, uncond_mask
|
| 425 |
uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
|
| 426 |
uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
|
| 427 |
-
uncond_pool = uncond_pool.to(dtype=dtype, device=device).repeat(batch_size, 1)
|
| 428 |
|
| 429 |
-
return uncond_emb, uncond_mask
|
| 430 |
|
| 431 |
# Получаем негативные (пустые) условия для валидации
|
| 432 |
-
uncond_emb, uncond_mask
|
| 433 |
|
| 434 |
-
# ---
|
| 435 |
@torch.compiler.disable()
|
| 436 |
@torch.no_grad()
|
| 437 |
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
|
| 438 |
-
uncond_emb, uncond_mask
|
| 439 |
|
| 440 |
original_model = None
|
| 441 |
try:
|
|
@@ -449,13 +509,12 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
|
|
| 449 |
all_generated_images = []
|
| 450 |
all_captions = []
|
| 451 |
|
| 452 |
-
# Распаковываем 5 элементов (добавились mask
|
| 453 |
-
for size, (sample_latents, sample_text_embeddings, sample_mask,
|
| 454 |
width, height = size
|
| 455 |
sample_latents = sample_latents.to(dtype=dtype, device=device)
|
| 456 |
sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
|
| 457 |
sample_mask = sample_mask.to(device=device)
|
| 458 |
-
sample_pool = sample_pool.to(dtype=dtype, device=device)
|
| 459 |
|
| 460 |
latents = torch.randn(
|
| 461 |
sample_latents.shape,
|
|
@@ -483,15 +542,10 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
|
|
| 483 |
neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
|
| 484 |
attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
|
| 485 |
|
| 486 |
-
# 3. Pools
|
| 487 |
-
neg_pool_batch = uncond_pool[0:1].expand(curr_batch_size, -1)
|
| 488 |
-
pooled_batch = torch.cat([neg_pool_batch, sample_pool], dim=0)
|
| 489 |
-
|
| 490 |
else:
|
| 491 |
latent_model_input = latents
|
| 492 |
text_embeddings_batch = sample_text_embeddings
|
| 493 |
attention_mask_batch = sample_mask
|
| 494 |
-
pooled_batch = sample_pool
|
| 495 |
|
| 496 |
# Предсказание с передачей всех условий
|
| 497 |
model_out = original_model(
|
|
@@ -499,7 +553,6 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
|
|
| 499 |
t,
|
| 500 |
encoder_hidden_states=text_embeddings_batch,
|
| 501 |
encoder_attention_mask=attention_mask_batch,
|
| 502 |
-
#added_cond_kwargs={"text_embeds": pooled_batch} # <--- ПУЛИНГ ЗДЕСЬ
|
| 503 |
)
|
| 504 |
flow = getattr(model_out, "sample", model_out)
|
| 505 |
|
|
@@ -560,7 +613,7 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
|
|
| 560 |
if accelerator.is_main_process:
|
| 561 |
if save_model:
|
| 562 |
print("Генерация сэмплов до старта обучения...")
|
| 563 |
-
generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask
|
| 564 |
accelerator.wait_for_everyone()
|
| 565 |
|
| 566 |
def save_checkpoint(unet, variant=""):
|
|
@@ -601,8 +654,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
| 601 |
accelerator.wait_for_everyone()
|
| 602 |
unet.train()
|
| 603 |
|
| 604 |
-
|
| 605 |
-
for step, (latents, embeddings, attention_mask, pooled_embeddings) in enumerate(dataloader):
|
| 606 |
with accelerator.accumulate(unet):
|
| 607 |
if save_model == False and step == 5 :
|
| 608 |
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
|
@@ -617,13 +669,12 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
| 617 |
# делаем integer timesteps для UNet
|
| 618 |
timesteps = (t * scheduler.config.num_train_timesteps).long()
|
| 619 |
|
| 620 |
-
# ---
|
| 621 |
model_pred = unet(
|
| 622 |
noisy_latents,
|
| 623 |
timesteps,
|
| 624 |
encoder_hidden_states=embeddings,
|
| 625 |
-
encoder_attention_mask=attention_mask
|
| 626 |
-
#added_cond_kwargs={"text_embeds": pooled_embeddings} # <--- Передача пулинга
|
| 627 |
).sample
|
| 628 |
|
| 629 |
target = noise - latents
|
|
@@ -641,8 +692,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
| 641 |
grad = 0.0
|
| 642 |
if not fbp:
|
| 643 |
if accelerator.sync_gradients:
|
| 644 |
-
with torch.amp.autocast('cuda', enabled=False):
|
| 645 |
-
|
| 646 |
grad = float(grad_val)
|
| 647 |
optimizer.step()
|
| 648 |
lr_scheduler.step()
|
|
@@ -669,8 +720,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
| 669 |
comet_experiment.log_metrics(log_data, step=global_step)
|
| 670 |
|
| 671 |
if global_step % sample_interval == 0:
|
| 672 |
-
# Передаем tuple (emb, mask
|
| 673 |
-
generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask
|
| 674 |
last_n = sample_interval
|
| 675 |
|
| 676 |
if save_model:
|
|
|
|
| 1 |
+
#from comet_ml import Experiment
|
| 2 |
import os
|
| 3 |
import math
|
| 4 |
import torch
|
|
|
|
| 29 |
# --------------------------- Параметры ---------------------------
|
| 30 |
ds_path = "/workspace/sdxs/datasets/768"
|
| 31 |
project = "unet"
|
| 32 |
+
batch_size = 60
|
| 33 |
+
base_learning_rate = 2.7e-5
|
| 34 |
+
min_learning_rate = 2.7e-5
|
| 35 |
+
num_epochs = 100
|
| 36 |
+
sample_interval_share = 5
|
| 37 |
max_length = 192
|
| 38 |
+
use_wandb = True
|
| 39 |
+
use_comet_ml = False
|
| 40 |
save_model = True
|
| 41 |
use_decay = True
|
| 42 |
fbp = False
|
|
|
|
| 51 |
torch.backends.cudnn.allow_tf32 = True
|
| 52 |
#torch.backends.cuda.enable_mem_efficient_sdp(False)
|
| 53 |
dtype = torch.float32
|
| 54 |
+
save_barrier = 1.006
|
| 55 |
warmup_percent = 0.01
|
| 56 |
percentile_clipping = 98
|
| 57 |
betta2 = 0.998
|
| 58 |
+
eps = 1e-7
|
| 59 |
clip_grad_norm = 1.0
|
| 60 |
limit = 0
|
| 61 |
checkpoints_folder = ""
|
|
|
|
| 171 |
batch_size = hidden.shape[0]
|
| 172 |
pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
|
| 173 |
|
| 174 |
+
#return hidden, attention_mask
|
| 175 |
# --- НОВАЯ ЛОГИКА: ОБЪЕДИНЕНИЕ ДЛЯ КРОСС-ВНИМАНИЯ ---
|
| 176 |
# 1. Расширяем пулинг-вектор до последовательности [B, 1, 1024]
|
| 177 |
pooled_expanded = pooled.unsqueeze(1)
|
|
|
|
| 186 |
# torch.ones((batch_size, 1), device=device) создает маску [B, 1] со значениями 1.
|
| 187 |
new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
|
| 188 |
|
| 189 |
+
return new_encoder_hidden_states, new_attention_mask
|
| 190 |
|
| 191 |
shift_factor = getattr(vae.config, "shift_factor", 0.0)
|
| 192 |
if shift_factor is None: shift_factor = 0.0
|
|
|
|
| 287 |
texts = [item["text"] for item in samples_data]
|
| 288 |
|
| 289 |
# Кодируем тексты на лету, чтобы получить маски и пулинг
|
| 290 |
+
embeddings, masks = encode_texts(texts)
|
| 291 |
|
| 292 |
+
fixed_samples[size] = (latents, embeddings, masks, texts)
|
| 293 |
|
| 294 |
print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
|
| 295 |
return fixed_samples
|
|
|
|
| 309 |
texts = [
|
| 310 |
"" if t.lower().startswith("zero")
|
| 311 |
else "" if random.random() < 0.05
|
| 312 |
+
else t[1:].lstrip() if t.startswith(".")
|
| 313 |
+
else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
|
| 314 |
for t in raw_texts
|
| 315 |
]
|
| 316 |
|
| 317 |
# 3. Кодируем на лету
|
| 318 |
+
# Возвращает: hidden (B, L, D), mask (B, L)
|
| 319 |
+
embeddings, attention_mask = encode_texts(texts)
|
| 320 |
|
| 321 |
# attention_mask от токенизатора уже имеет нужный формат, но на всякий случай приведем к long
|
| 322 |
attention_mask = attention_mask.to(dtype=torch.int64)
|
| 323 |
|
| 324 |
+
return latents, embeddings, attention_mask
|
| 325 |
|
| 326 |
batch_sampler = DistributedResolutionBatchSampler(
|
| 327 |
dataset=dataset,
|
|
|
|
| 375 |
)
|
| 376 |
elif name == "adam":
|
| 377 |
return torch.optim.AdamW(
|
| 378 |
+
params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
|
| 379 |
)
|
| 380 |
+
elif name == "adamuon":
|
| 381 |
+
from pytorch_optimizer import AdaMuon
|
| 382 |
+
# 1. РАЗДЕЛЕНИЕ ПАРАМЕТРОВ
|
| 383 |
+
# Разделяем параметры на те, которые Muon-совместимы (ndim >= 2),
|
| 384 |
+
# и те, которые лучше оставить для AdamW8bit (ndim < 2).
|
| 385 |
+
|
| 386 |
+
muon_params = []
|
| 387 |
+
adam8bit_params = []
|
| 388 |
+
|
| 389 |
+
if isinstance(params, (list, tuple)):
|
| 390 |
+
all_params = params
|
| 391 |
+
else:
|
| 392 |
+
# Если params - это генератор (unet.parameters()), преобразуем его в list
|
| 393 |
+
all_params = list(params)
|
| 394 |
+
|
| 395 |
+
for p in all_params:
|
| 396 |
+
if p.requires_grad:
|
| 397 |
+
if p.ndim >= 2:
|
| 398 |
+
muon_params.append(p)
|
| 399 |
+
else:
|
| 400 |
+
adam8bit_params.append(p)
|
| 401 |
+
|
| 402 |
+
# 2. СОЗДАНИЕ ГРУПП ПАРАМЕТРОВ
|
| 403 |
+
param_groups = []
|
| 404 |
+
|
| 405 |
+
# --- Группа 1: AdaMuon (для матриц/сверток) ---
|
| 406 |
+
if muon_params:
|
| 407 |
+
print(f"Используем AdaMuon для {len(muon_params)} параметров (ndim >= 2)")
|
| 408 |
+
param_groups.append({
|
| 409 |
+
'params': muon_params,
|
| 410 |
+
'optimizer': 'AdaMuon', # Явно указываем оптимизатор
|
| 411 |
+
'use_muon': True, # Обязательно для AdaMuon
|
| 412 |
+
'lr': base_learning_rate, # Использование обычных пробелов
|
| 413 |
+
'betas': (0.9, betta2),
|
| 414 |
+
'eps': eps,
|
| 415 |
+
'weight_decay': 0.01,
|
| 416 |
+
'nesterov': True,
|
| 417 |
+
'decouple_weight_decay': True
|
| 418 |
+
})
|
| 419 |
+
|
| 420 |
+
# --- Группа 2: AdamW8bit (для смещений/нормализации) ---
|
| 421 |
+
if adam8bit_params:
|
| 422 |
+
print(f"Используем AdamW8bit для {len(adam8bit_params)} параметров (ndim < 2)")
|
| 423 |
+
param_groups.append({
|
| 424 |
+
'params': adam8bit_params,
|
| 425 |
+
'optimizer': 'AdamW8bit', # Явно указываем оптимизатор
|
| 426 |
+
'use_muon': False,
|
| 427 |
+
# AdamW8bit не требует use_muon
|
| 428 |
+
'lr': base_learning_rate, # Использование обычных пробелов
|
| 429 |
+
'betas': (0.9, betta2),
|
| 430 |
+
'eps': eps,
|
| 431 |
+
'weight_decay': 0.01,
|
| 432 |
+
'percentile_clipping': percentile_clipping
|
| 433 |
+
})
|
| 434 |
+
|
| 435 |
+
if not param_groups:
|
| 436 |
+
raise ValueError("Нет параметров для оптимизации.")
|
| 437 |
+
|
| 438 |
+
# 3. ИНИЦИАЛИЗАЦИЯ (смешанный оптимизатор)
|
| 439 |
+
# При передаче списка групп оптимизатор AdaMuon автоматически
|
| 440 |
+
# распознает ключи 'optimizer' и инициализирует внутренние оптимизаторы.
|
| 441 |
+
return AdaMuon(param_groups)
|
| 442 |
else:
|
| 443 |
raise ValueError(f"Unknown optimizer: {name}")
|
| 444 |
|
|
|
|
| 480 |
seq_len = max_length
|
| 481 |
empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
|
| 482 |
empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
|
| 483 |
+
return empty_emb, empty_mask
|
|
|
|
| 484 |
|
| 485 |
+
uncond_emb, uncond_mask = encode_texts([neg_prompt])
|
| 486 |
uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
|
| 487 |
uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
|
|
|
|
| 488 |
|
| 489 |
+
return uncond_emb, uncond_mask
|
| 490 |
|
| 491 |
# Получаем негативные (пустые) условия для валидации
|
| 492 |
+
uncond_emb, uncond_mask = get_negative_embedding("low quality")
|
| 493 |
|
| 494 |
+
# --- Функция генерации семплов ---
|
| 495 |
@torch.compiler.disable()
|
| 496 |
@torch.no_grad()
|
| 497 |
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
|
| 498 |
+
uncond_emb, uncond_mask = uncond_data
|
| 499 |
|
| 500 |
original_model = None
|
| 501 |
try:
|
|
|
|
| 509 |
all_generated_images = []
|
| 510 |
all_captions = []
|
| 511 |
|
| 512 |
+
# Распаковываем 5 элементов (добавились mask)
|
| 513 |
+
for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
|
| 514 |
width, height = size
|
| 515 |
sample_latents = sample_latents.to(dtype=dtype, device=device)
|
| 516 |
sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
|
| 517 |
sample_mask = sample_mask.to(device=device)
|
|
|
|
| 518 |
|
| 519 |
latents = torch.randn(
|
| 520 |
sample_latents.shape,
|
|
|
|
| 542 |
neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
|
| 543 |
attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
else:
|
| 546 |
latent_model_input = latents
|
| 547 |
text_embeddings_batch = sample_text_embeddings
|
| 548 |
attention_mask_batch = sample_mask
|
|
|
|
| 549 |
|
| 550 |
# Предсказание с передачей всех условий
|
| 551 |
model_out = original_model(
|
|
|
|
| 553 |
t,
|
| 554 |
encoder_hidden_states=text_embeddings_batch,
|
| 555 |
encoder_attention_mask=attention_mask_batch,
|
|
|
|
| 556 |
)
|
| 557 |
flow = getattr(model_out, "sample", model_out)
|
| 558 |
|
|
|
|
| 613 |
if accelerator.is_main_process:
|
| 614 |
if save_model:
|
| 615 |
print("Генерация сэмплов до старта обучения...")
|
| 616 |
+
generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
|
| 617 |
accelerator.wait_for_everyone()
|
| 618 |
|
| 619 |
def save_checkpoint(unet, variant=""):
|
|
|
|
| 654 |
accelerator.wait_for_everyone()
|
| 655 |
unet.train()
|
| 656 |
|
| 657 |
+
for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
|
|
|
|
| 658 |
with accelerator.accumulate(unet):
|
| 659 |
if save_model == False and step == 5 :
|
| 660 |
used_gb = torch.cuda.max_memory_allocated() / 1024**3
|
|
|
|
| 669 |
# делаем integer timesteps для UNet
|
| 670 |
timesteps = (t * scheduler.config.num_train_timesteps).long()
|
| 671 |
|
| 672 |
+
# --- Вызов UNet с маской ---
|
| 673 |
model_pred = unet(
|
| 674 |
noisy_latents,
|
| 675 |
timesteps,
|
| 676 |
encoder_hidden_states=embeddings,
|
| 677 |
+
encoder_attention_mask=attention_mask
|
|
|
|
| 678 |
).sample
|
| 679 |
|
| 680 |
target = noise - latents
|
|
|
|
| 692 |
grad = 0.0
|
| 693 |
if not fbp:
|
| 694 |
if accelerator.sync_gradients:
|
| 695 |
+
#with torch.amp.autocast('cuda', enabled=False):
|
| 696 |
+
grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
|
| 697 |
grad = float(grad_val)
|
| 698 |
optimizer.step()
|
| 699 |
lr_scheduler.step()
|
|
|
|
| 720 |
comet_experiment.log_metrics(log_data, step=global_step)
|
| 721 |
|
| 722 |
if global_step % sample_interval == 0:
|
| 723 |
+
# Передаем tuple (emb, mask) для негатива
|
| 724 |
+
generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
|
| 725 |
last_n = sample_interval
|
| 726 |
|
| 727 |
if save_model:
|
unet/config.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ab7222cdd538ff5178adc870a764d22ab24a185f0a7b63852ea728b3b09fcff
|
| 3 |
+
size 1876
|
unet/diffusion_pytorch_model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b202a600c7baf31b986d98f00109d0c396f0fddc91bab63c49a01c3477eeba1
|
| 3 |
+
size 6604736640
|