recoilme committed
Commit 6df0c8d · Parent: f575c8a
dataset.py CHANGED
@@ -27,8 +27,8 @@ step = 32 #64
empty_share = 0.0
limit = 0
# Main processing procedure
- folder_path = "/workspace/ds" #alchemist"
- save_path = "/workspace/768" #"alchemist"
+ folder_path = "/workspace/alchemist" #alchemist"
+ save_path = "/workspace/sdxs/datasets/alchemist" #"alchemist"
os.makedirs(save_path, exist_ok=True)

# Function for clearing CUDA memory
@@ -394,6 +394,13 @@ print(f"Found {len(image_paths)} images in total")
# Chunked processing
process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size)

+ # Remove the source folder
+ try:
+     shutil.rmtree(folder_path)
+     print(f"✅ Folder {folder_path} removed successfully")
+ except Exception as e:
+     print(f"⚠️ Error while removing folder: {e}")
+
# Combine the chunks into the final dataset
combine_chunks(temp_path, save_path)
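The new cleanup step assumes shutil is imported at the top of dataset.py (the import sits outside this hunk). A minimal standalone sketch of what it does, with the paths configured above:

    import shutil

    folder_path = "/workspace/alchemist"  # source-image folder, as set above
    try:
        # recursively delete the already-processed source folder to free disk space
        shutil.rmtree(folder_path)
        print(f"✅ Folder {folder_path} removed successfully")
    except Exception as e:
        # report FileNotFoundError, PermissionError, etc. instead of aborting the run
        print(f"⚠️ Error while removing folder: {e}")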
 
media/result_grid.jpg CHANGED

Git LFS Details (before)

  • SHA256: be304fdd651f5cb73ffccd5d56c02991330e9d7d725dee385d60090f07b3123c
  • Pointer size: 132 Bytes
  • Size of remote file: 4.42 MB

Git LFS Details (after)

  • SHA256: 7180ef8fc767e120ac38569d218b221eb11c56fa3a718a554ccc9f2301b0253f
  • Pointer size: 132 Bytes
  • Size of remote file: 5.37 MB
samples/unet_384x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 87ea16b6a264e247da3287c5069b2e332c39504ae55cecc2cd33059103765a58
  • Pointer size: 130 Bytes
  • Size of remote file: 66.1 kB

Git LFS Details (after)

  • SHA256: dc7cf8347165908d515bbd58c69a7997c60bc399a510a54ce3f7cdc578e55272
  • Pointer size: 131 Bytes
  • Size of remote file: 196 kB
samples/unet_416x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 4d9f305603aa3ffdd6858a79fc17924ddbbca4946c0563c995b58b7a41f4fd1c
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB

Git LFS Details (after)

  • SHA256: c503cc7a8754a496f048f62a0f3dc44292795114b963705cedff8591b5cf8095
  • Pointer size: 130 Bytes
  • Size of remote file: 83.2 kB
samples/unet_448x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 7c79dd999533dda77936afcfb43d77564bff3d46b050eee3c8620e897a9fc917
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB

Git LFS Details (after)

  • SHA256: 8620b9c0bc250a4891965081a0ac1d3f0f60601fd9c3886081cc13a23eec83f3
  • Pointer size: 130 Bytes
  • Size of remote file: 60.2 kB
samples/unet_480x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 0b50d094956d7c966922a068b4265b9a1d0ca9a5a4ac431211f185d83459bacd
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB

Git LFS Details (after)

  • SHA256: ef43348df1f6b22b80dc17bda15d9e92d981eb962f1451d655d926b85e5ebd86
  • Pointer size: 131 Bytes
  • Size of remote file: 188 kB
samples/unet_512x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: fc5e128ce9f530c1805539d78b818a64cc1f35610062e58d8379517a2552b374
  • Pointer size: 130 Bytes
  • Size of remote file: 77.2 kB

Git LFS Details (after)

  • SHA256: 2065cbdb2f54e0cfdc7867ce64b95795d017eeb31152791b50c4f36830fdf751
  • Pointer size: 131 Bytes
  • Size of remote file: 226 kB
samples/unet_544x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 1fdc9bfe0b10fcae9d370c736665a5fd2200be8fe421cac3be7d5a7a275672cb
  • Pointer size: 130 Bytes
  • Size of remote file: 85.7 kB

Git LFS Details (after)

  • SHA256: 942719c5a9239737eca1281c2164b622232ab9416c0eaff26a76e7b2fda99d3f
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
samples/unet_576x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 7c640984ccb67f0acdf3f0734f98db1595d78be9de3df5d0815ba51e4fca3873
  • Pointer size: 131 Bytes
  • Size of remote file: 290 kB

Git LFS Details (after)

  • SHA256: 492be57c73fbb2b94ce8ace010510647b32a4c7c4c4f05651e790aad060f90ea
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB
samples/unet_608x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: be21b0ec3ab9ccbfc4897cb355f1239f93b2baae934c6e96454748ec391f59c8
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB

Git LFS Details (after)

  • SHA256: 55b9a8ecb378bea8a39d12af76efd9959679721ab2aab7ddeb19bf0e400a87a6
  • Pointer size: 131 Bytes
  • Size of remote file: 172 kB
samples/unet_640x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 1c18aceae372fc77055bac9453ebdbafe1ec3fb324540522784a9b099e1bc812
  • Pointer size: 130 Bytes
  • Size of remote file: 91.3 kB

Git LFS Details (after)

  • SHA256: d6cb49c9e2d09d7829210cd4c8d614e7773d42823e740985770559044d4c8e16
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
samples/unet_672x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 6f691232ec6308d6a144d6d7472b922731bc0c382237d6e4b64908fd0a85f1a5
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB

Git LFS Details (after)

  • SHA256: 5969f754d1c1a5c774ee54a143038980f6be054aad3663ffea908de9dde494d1
  • Pointer size: 130 Bytes
  • Size of remote file: 69.8 kB
samples/unet_704x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: a4921da6a4a994f11e76ef804c088cb794cd2dfad5f334cc03ea746f4f8b867d
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB

Git LFS Details (after)

  • SHA256: 4503966b6019cc1462b79dfdde243a33925c50b4ee92b647d08c0bb0a36e4ca1
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
samples/unet_736x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 3855348d5f1f4d4300b64fb733861d73a82dd45f0d9a0856024d9d463893c571
  • Pointer size: 130 Bytes
  • Size of remote file: 48.2 kB

Git LFS Details (after)

  • SHA256: 87ca0a06d8f3af59ad1b0e1767f8f5396f40fea713781f97be7f19202975d0b5
  • Pointer size: 130 Bytes
  • Size of remote file: 82.5 kB
samples/unet_768x384_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: ae210869b1ca765f04cd979b9962c07f8275c3b814bd60bb9d05d34a00dae853
  • Pointer size: 131 Bytes
  • Size of remote file: 165 kB

Git LFS Details (after)

  • SHA256: 7b3bead76b8c19cd370c82c6f4449cd84a1f487228817eeb2f2ab2854e5526d5
  • Pointer size: 131 Bytes
  • Size of remote file: 170 kB
samples/unet_768x416_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: ce6d00486ae7c81d2ec8ea2e41dfe21ae49cbbcfd75a0a4e4f2e02c6c09a47c0
  • Pointer size: 130 Bytes
  • Size of remote file: 87.3 kB

Git LFS Details (after)

  • SHA256: 8244a4630fc0d34e05e5eb801e34495bc29a429bfd4e195ac1186be30141626a
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
samples/unet_768x448_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 1331bbf43e99ea75d80d2c01598e187b6e3f91fe5c1213247c1dd719d2dc31fd
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB

Git LFS Details (after)

  • SHA256: f9b901ef7a15a341c454926294aec16df7458b32bf878741ef2d15e1307fd9fe
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
samples/unet_768x480_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 31b71dbe5fc2053c4bab6997b3e324ad2919c43a0353dc3326a28f1495035b55
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB

Git LFS Details (after)

  • SHA256: 1cc0c2bf521ecdd56937f4289f123a7092c9a3db2fd31e5b8c51628a44ee2054
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
samples/unet_768x512_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: d92aa72f3bfd16adb23cfb2f486f53c80f50a06455e5e0956428ef371bfd9630
  • Pointer size: 131 Bytes
  • Size of remote file: 298 kB

Git LFS Details (after)

  • SHA256: 07f0547d2ad34814d22624e4f189be9ce3137d2b633b9caf47c2ad97efda9ecc
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
samples/unet_768x544_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 44a92d99624a6b873295899db86ee853ca43e5f45f785eb56b4beb7207522a0b
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB

Git LFS Details (after)

  • SHA256: 4042250d6f689d23fb2d4bc716d17d6da0dbff91fed6be40b03e0f9d2c2f1638
  • Pointer size: 130 Bytes
  • Size of remote file: 85.4 kB
samples/unet_768x576_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: bcd2f8544b0dcc5c7e738699b5641be63e592c1d9056cc2b3f7504c885e5f9ac
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB

Git LFS Details (after)

  • SHA256: 1bed1b9846d19cd9fe2a92434197750ad78401d31ce864146acc68bf135e5284
  • Pointer size: 131 Bytes
  • Size of remote file: 192 kB
samples/unet_768x608_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: adeb415b9591ada3681039eb6b56aba22852656a11f636485417d3bda53ad50f
  • Pointer size: 131 Bytes
  • Size of remote file: 188 kB

Git LFS Details (after)

  • SHA256: 899b41f901562dae78dfbaea7c5fc05c7ea077ec70a9a0a55a9a51ee84d340e3
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB
samples/unet_768x640_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 963396d5ceac0a41ec025dd5d9a8d61588e3b1ea7a6d532847d1393e44e64a01
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB

Git LFS Details (after)

  • SHA256: 9620d1d431a896209545e80dab20944faa8a9caf05e36908390f958487d4f23b
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB
samples/unet_768x672_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 1ce5fb2a8422c74d86d57cdb0e0437c0215ccbcab9f86d1a26a641b92d9978b9
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB

Git LFS Details (after)

  • SHA256: c08b0ab48096384e4d4d1fdc15af058de1cf4e993729b84d375d1abddf2883de
  • Pointer size: 131 Bytes
  • Size of remote file: 184 kB
samples/unet_768x704_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 9e38bac28ea95caa0f61c0bb0650add0a7792642b07cc94ae326eb9c9191f297
  • Pointer size: 130 Bytes
  • Size of remote file: 68.2 kB

Git LFS Details (after)

  • SHA256: 7589468ef88774960922e073fbb5e635e0c89c51315492f2cf3765df7f2b5a88
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
samples/unet_768x736_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: 632e20bc44d509b11ad2e0a5316188e21b900cf8063a187091137043a74ecc59
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB

Git LFS Details (after)

  • SHA256: 330ef60f7230e1bb62724446c2bcb574367ecca211a6454e0a6f5c7a07e0fe76
  • Pointer size: 131 Bytes
  • Size of remote file: 125 kB
samples/unet_768x768_0.jpg CHANGED

Git LFS Details (before)

  • SHA256: f3be6993576cd7a0a506cf5a4f0b39d1eb9143be82fa1341777c522457852223
  • Pointer size: 130 Bytes
  • Size of remote file: 53.5 kB

Git LFS Details (after)

  • SHA256: 4200e52ad503a59c29fdea0f594c3ca8c0323eca089a71f37bc718bc08a16cd8
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
test.ipynb CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
- oid sha256:a85017c24383c5064153544e98810e64a43a2d9c1b0ce79362f101391f84968b
- size 1723522
+ oid sha256:fb11577f5bc95ddaec1ac5dfbd4090581c6ea31c5c5e0bbc1d616d5948ec98aa
+ size 9700840
train.py CHANGED
@@ -1,4 +1,4 @@
- from comet_ml import Experiment
+ #from comet_ml import Experiment
import os
import math
import torch
@@ -29,14 +29,14 @@ from transformers import AutoTokenizer, AutoModel
# --------------------------- Parameters ---------------------------
ds_path = "/workspace/sdxs/datasets/768"
project = "unet"
- batch_size = 256
- base_learning_rate = 3e-5
- min_learning_rate = 2.5e-5
- num_epochs = 10
- sample_interval_share = 20
+ batch_size = 60
+ base_learning_rate = 2.7e-5
+ min_learning_rate = 2.7e-5
+ num_epochs = 100
+ sample_interval_share = 5
max_length = 192
- use_wandb = False
- use_comet_ml = True
+ use_wandb = True
+ use_comet_ml = False
save_model = True
use_decay = True
fbp = False
@@ -51,11 +51,11 @@ torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
#torch.backends.cuda.enable_mem_efficient_sdp(False)
dtype = torch.float32
- save_barrier = 1.005
+ save_barrier = 1.006
warmup_percent = 0.01
percentile_clipping = 98
betta2 = 0.998
- eps = 1e-6
+ eps = 1e-7
clip_grad_norm = 1.0
limit = 0
checkpoints_folder = ""
@@ -171,7 +171,7 @@ def encode_texts(texts, max_length=max_length):
    batch_size = hidden.shape[0]
    pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]

-     #return hidden, attention_mask, pooled
+     #return hidden, attention_mask
    # --- NEW LOGIC: PACKING FOR CROSS-ATTENTION ---
    # 1. Expand the pooled vector into a sequence [B, 1, 1024]
    pooled_expanded = pooled.unsqueeze(1)
@@ -186,7 +186,7 @@ def encode_texts(texts, max_length=max_length):
    # torch.ones((batch_size, 1), device=device) creates a [B, 1] mask filled with ones.
    new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)

-     return new_encoder_hidden_states, new_attention_mask, pooled
+     return new_encoder_hidden_states, new_attention_mask

shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None: shift_factor = 0.0
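For reference, the packing that encode_texts now performs can be checked in isolation. A minimal sketch with dummy tensors (shapes taken from the comments above, hidden size 1024 assumed):

    import torch

    B, L, D = 2, 191, 1024
    hidden = torch.randn(B, L, D)          # encoder hidden states [B, L, D]
    attention_mask = torch.ones(B, L)      # tokenizer mask [B, L]
    pooled = hidden[:, -1, :]              # stand-in for the last-token pooling

    pooled_expanded = pooled.unsqueeze(1)  # [B, 1, D]
    new_hidden = torch.cat([pooled_expanded, hidden], dim=1)         # [B, L+1, D]
    new_mask = torch.cat([torch.ones(B, 1), attention_mask], dim=1)  # [B, L+1]
    assert new_hidden.shape == (B, L + 1, D) and new_mask.shape == (B, L + 1)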
@@ -287,9 +287,9 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    texts = [item["text"] for item in samples_data]

    # Encode the texts on the fly to get the masks and pooling
-     embeddings, masks, pools = encode_texts(texts)
+     embeddings, masks = encode_texts(texts)

-     fixed_samples[size] = (latents, embeddings, masks, pools, texts)
+     fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Created {len(fixed_samples)} fixed-sample groups by resolution")
    return fixed_samples
@@ -309,19 +309,19 @@ def collate_fn_simple(batch):
    texts = [
        "" if t.lower().startswith("zero")
        else "" if random.random() < 0.05
-         else t[1:].lstrip() if t.startswith(".")
-         else t
+         else t[1:].lstrip() if t.startswith(".")
+         else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
        for t in raw_texts
    ]

    # 3. Encode on the fly
-     # Returns: hidden (B, L, D), mask (B, L), pooled (B, D)
-     embeddings, attention_mask, pooled_embeddings = encode_texts(texts)
+     # Returns: hidden (B, L, D), mask (B, L)
+     embeddings, attention_mask = encode_texts(texts)

    # the attention_mask from the tokenizer already has the right format, but cast it to long just in case
    attention_mask = attention_mask.to(dtype=torch.int64)

-     return latents, embeddings, attention_mask, pooled_embeddings
+     return latents, embeddings, attention_mask

batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
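The new caption cleanup in collate_fn_simple can be sanity-checked on its own; the 5% random drop is omitted here for determinism:

    raw_texts = ["The image shows a red fox.", ".a pencil sketch", "zero caption"]
    texts = [
        "" if t.lower().startswith("zero")
        else t[1:].lstrip() if t.startswith(".")
        else t.replace("The image shows ", "").replace("The image is ", "")
              .replace("This image captures ", "").strip()
        for t in raw_texts
    ]
    print(texts)  # ['a red fox.', 'a pencil sketch', '']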
@@ -375,8 +375,70 @@ def create_optimizer(name, params):
        )
    elif name == "adam":
        return torch.optim.AdamW(
-             params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
+             params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
        )
+     elif name == "adamuon":
+         from pytorch_optimizer import AdaMuon
+         # 1. PARAMETER SPLIT
+         # Split the parameters into Muon-compatible ones (ndim >= 2)
+         # and those better left to AdamW8bit (ndim < 2).
+
+         muon_params = []
+         adam8bit_params = []
+
+         if isinstance(params, (list, tuple)):
+             all_params = params
+         else:
+             # if params is a generator (unet.parameters()), turn it into a list
+             all_params = list(params)
+
+         for p in all_params:
+             if p.requires_grad:
+                 if p.ndim >= 2:
+                     muon_params.append(p)
+                 else:
+                     adam8bit_params.append(p)
+
+         # 2. BUILD THE PARAMETER GROUPS
+         param_groups = []
+
+         # --- Group 1: AdaMuon (for matrices/convolutions) ---
+         if muon_params:
+             print(f"Using AdaMuon for {len(muon_params)} parameters (ndim >= 2)")
+             param_groups.append({
+                 'params': muon_params,
+                 'optimizer': 'AdaMuon',  # name the optimizer explicitly
+                 'use_muon': True,        # required for AdaMuon
+                 'lr': base_learning_rate,
+                 'betas': (0.9, betta2),
+                 'eps': eps,
+                 'weight_decay': 0.01,
+                 'nesterov': True,
+                 'decouple_weight_decay': True
+             })
+
+         # --- Group 2: AdamW8bit (for biases/normalization) ---
+         if adam8bit_params:
+             print(f"Using AdamW8bit for {len(adam8bit_params)} parameters (ndim < 2)")
+             param_groups.append({
+                 'params': adam8bit_params,
+                 'optimizer': 'AdamW8bit',  # name the optimizer explicitly
+                 'use_muon': False,         # AdamW8bit does not require use_muon
+                 'lr': base_learning_rate,
+                 'betas': (0.9, betta2),
+                 'eps': eps,
+                 'weight_decay': 0.01,
+                 'percentile_clipping': percentile_clipping
+             })
+
+         if not param_groups:
+             raise ValueError("No parameters to optimize.")
+
+         # 3. INITIALIZATION (mixed optimizer)
+         # given a list of groups, the AdaMuon optimizer picks up the
+         # 'optimizer' keys and initializes the inner optimizers.
+         return AdaMuon(param_groups)
    else:
        raise ValueError(f"Unknown optimizer: {name}")
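The ndim-based split is the heart of the new branch: matrix- and conv-shaped weights get Muon-style updates, while 1-D tensors (biases, norm scales) stay on an Adam-style group. A quick check of the split with plain torch modules (the AdaMuon/AdamW8bit wiring itself is as in the diff above):

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(2, 8))
    muon_params = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2]
    adam_params = [p for p in model.parameters() if p.requires_grad and p.ndim < 2]
    # the conv weight (4-D) is Muon-compatible; the conv bias and the two
    # GroupNorm affine parameters (all 1-D) fall back to the Adam-style group
    print(len(muon_params), len(adam_params))  # 1 3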
 
@@ -418,24 +480,22 @@ def get_negative_embedding(neg_prompt="", batch_size=1):
        seq_len = max_length
        empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
-         empty_pool = torch.zeros((batch_size, hidden_dim), dtype=dtype, device=device)
-         return empty_emb, empty_mask, empty_pool
+         return empty_emb, empty_mask

-     uncond_emb, uncond_mask, uncond_pool = encode_texts([neg_prompt])
+     uncond_emb, uncond_mask = encode_texts([neg_prompt])
    uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
-     uncond_pool = uncond_pool.to(dtype=dtype, device=device).repeat(batch_size, 1)

-     return uncond_emb, uncond_mask, uncond_pool
+     return uncond_emb, uncond_mask

# Get the negative (empty) conditions for validation
- uncond_emb, uncond_mask, uncond_pool = get_negative_embedding("low quality")
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")

- # --- [UPDATED] Sample-generation function (takes the pool) ---
+ # --- Sample-generation function ---
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
-     uncond_emb, uncond_mask, uncond_pool = uncond_data
+     uncond_emb, uncond_mask = uncond_data

    original_model = None
    try:
@@ -449,13 +509,12 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    all_generated_images = []
    all_captions = []

-     # Unpack 5 elements (mask and pool were added)
-     for size, (sample_latents, sample_text_embeddings, sample_mask, sample_pool, sample_text) in fixed_samples_cpu.items():
+     # Unpack 4 elements (mask was added)
+     for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
        width, height = size
        sample_latents = sample_latents.to(dtype=dtype, device=device)
        sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
        sample_mask = sample_mask.to(device=device)
-         sample_pool = sample_pool.to(dtype=dtype, device=device)

        latents = torch.randn(
            sample_latents.shape,
@@ -483,15 +542,10 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
            neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
            attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)

-             # 3. Pools
-             neg_pool_batch = uncond_pool[0:1].expand(curr_batch_size, -1)
-             pooled_batch = torch.cat([neg_pool_batch, sample_pool], dim=0)
-
        else:
            latent_model_input = latents
            text_embeddings_batch = sample_text_embeddings
            attention_mask_batch = sample_mask
-             pooled_batch = sample_pool

        # Prediction, passing all the conditions
        model_out = original_model(
@@ -499,7 +553,6 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
            t,
            encoder_hidden_states=text_embeddings_batch,
            encoder_attention_mask=attention_mask_batch,
-             #added_cond_kwargs={"text_embeds": pooled_batch} # <--- POOLING GOES HERE
        )
        flow = getattr(model_out, "sample", model_out)
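For context, the negative/positive batching above is the usual classifier-free guidance pattern: both conditions pass through the UNet in a single forward pass and the two flows are recombined afterwards. A hedged sketch with dummy shapes (guidance_scale and the recombination step live outside this hunk and are assumptions here):

    import torch

    guidance_scale = 5.0  # assumption: defined elsewhere in train.py
    curr_batch_size, L, D = 2, 193, 1024
    latents = torch.randn(curr_batch_size, 16, 96, 96)          # dummy latent shape
    uncond_batch = torch.randn(curr_batch_size, L, D)           # negatives first, as above
    cond_batch = torch.randn(curr_batch_size, L, D)

    latent_model_input = torch.cat([latents, latents], dim=0)   # [2B, ...]
    text_embeddings_batch = torch.cat([uncond_batch, cond_batch], dim=0)
    flow = torch.randn_like(latent_model_input)                 # stand-in for model_out.sample
    flow_uncond, flow_cond = flow.chunk(2, dim=0)
    guided_flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)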
 
@@ -560,7 +613,7 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
if accelerator.is_main_process:
    if save_model:
        print("Generating samples before training starts...")
-         generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask, uncond_pool), 0)
+         generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
accelerator.wait_for_everyone()

def save_checkpoint(unet, variant=""):
@@ -601,8 +654,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
    accelerator.wait_for_everyone()
    unet.train()

-     # --- [UPDATED] Unpack 4 elements ---
-     for step, (latents, embeddings, attention_mask, pooled_embeddings) in enumerate(dataloader):
+     for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        with accelerator.accumulate(unet):
            if save_model == False and step == 5 :
                used_gb = torch.cuda.max_memory_allocated() / 1024**3
@@ -617,13 +669,12 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
            # make integer timesteps for the UNet
            timesteps = (t * scheduler.config.num_train_timesteps).long()

-             # --- [UPDATED] UNet call with mask and pooling ---
+             # --- UNet call with mask ---
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=embeddings,
-                 encoder_attention_mask=attention_mask,
-                 #added_cond_kwargs={"text_embeds": pooled_embeddings} # <--- pass the pooling
+                 encoder_attention_mask=attention_mask
            ).sample

            target = noise - latents
@@ -641,8 +692,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
            grad = 0.0
            if not fbp:
                if accelerator.sync_gradients:
-                     with torch.amp.autocast('cuda', enabled=False):
-                         grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
+                     #with torch.amp.autocast('cuda', enabled=False):
+                     grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
                    grad = float(grad_val)
                optimizer.step()
                lr_scheduler.step()
@@ -669,8 +720,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
                comet_experiment.log_metrics(log_data, step=global_step)

            if global_step % sample_interval == 0:
-                 # pass the (emb, mask, pool) tuple for the negative
-                 generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask, uncond_pool), global_step)
+                 # pass the (emb, mask) tuple for the negative
+                 generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                last_n = sample_interval

            if save_model:
 
unet/config.json CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
- oid sha256:987ce2ea59106bf0806b16c21fff90d6fe6e7a90101e0157d8295803dbf34f2b
- size 1892
+ oid sha256:8ab7222cdd538ff5178adc870a764d22ab24a185f0a7b63852ea728b3b09fcff
+ size 1876
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
- oid sha256:1924f863b340d5698a8f86471a5ac9ad0f6eda9477d1b43a1a7a146300c5bb60
- size 3302474272
+ oid sha256:2b202a600c7baf31b986d98f00109d0c396f0fddc91bab63c49a01c3477eeba1
+ size 6604736640