Text-to-Image · Diffusers · Safetensors

recoilme committed
Commit 7d8ea4a · 1 Parent(s): 4ac3097
1.png ADDED

Git LFS Details

  • SHA256: 7ad632fa12f2653b831189dea7f2eeb5d525a58c5b0cf35f5e8364b7205b7d4b
  • Pointer size: 131 Bytes
  • Size of remote file: 736 kB
README.md CHANGED
@@ -79,6 +79,13 @@ image.show()
 upscaled = pipe.image_upscale("media/girl.jpg")
 ```
 
+```markdown
+| | |
+|:---:|:---:|
+| <img src="media/123456789.jpg" height="512"/> | <img src="media/123456789.png" height="512"/> |
+| original | 2xupscale |
+```
+
 ### Prompt refine
 ```
 refined = pipe.refine_prompts("girl")
@@ -214,6 +221,10 @@ nohup accelerate launch train.py &
 - Rubles: [For users from Russia](https://www.tbank.ru/cf/90ensBQqpJj)
 - DOGE: DEw2DR8C7BnF8GgcrfTzUjSnGkuMeJhg83
 - BTC: 3JHv9Hb8kEW8zMAccdgCdZGfrHeMhH1rpN
+USDT
+- Ethereum / Polygon / BNB SmartChain: 0xD4388B6698dFaE1460E72099D4F208aaCA4f6E6C
+- Tron: TD7ey4h9igPGdcrcBcnZaz56R5tNgRZNvV
+- Solana: MMYFJeYEtYHrSNFHChytJDHbEDniXrnAxPNLhJ1LbkB
 
 ## Contacts
 Please contact us if you can provide GPUs or funding for training
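For orientation: the `image_upscale` and `refine_prompts` calls quoted in this README hunk are methods of the repository's custom `SdxsPipeline` (edited in `pipeline_sdxs.py` below). A minimal usage sketch, not part of this commit — the repo id and the `custom_pipeline` wiring are assumptions:

```python
import torch
from diffusers import DiffusionPipeline

# Assumption: the weights and pipeline_sdxs.py live in the same Hub repo;
# "recoilme/sdxs" is a placeholder id, not confirmed by this commit.
repo = "recoilme/sdxs"
pipe = DiffusionPipeline.from_pretrained(
    repo,
    custom_pipeline=repo,       # loads SdxsPipeline from pipeline_sdxs.py
    torch_dtype=torch.float16,
).to("cuda")

upscaled = pipe.image_upscale("media/girl.jpg")  # 2x upscale, as in the table above
refined = pipe.refine_prompts("girl")            # prompt refinement helper
```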
media/123456789.jpg ADDED

Git LFS Details

  • SHA256: 131522c2f1db361170fb7f8819138893ccec8c1be544509b03aee277c3118e31
  • Pointer size: 131 Bytes
  • Size of remote file: 215 kB
media/123456789.png ADDED

Git LFS Details

  • SHA256: 23cccef4940e4899124a63a96c4f9efa1eda83d488337869f0fe4a89afee7d57
  • Pointer size: 132 Bytes
  • Size of remote file: 3.11 MB
media/girl.jpg CHANGED

Git LFS Details

  • SHA256: d034b43e0ac84a4ea1c2f270338d532a5fd9f7f2f0244e1c230292e7308293ce
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB

Git LFS Details

  • SHA256: 3b246b6f9af0e3320cc4d74f3b419dc30c85fb01defa89b58bf36818cd7f520a
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
media/result_grid.jpg CHANGED

Git LFS Details

  • SHA256: b777e000e154f8b337da1d9834b4d34620bf09dd61581ee388db07bc965b2647
  • Pointer size: 132 Bytes
  • Size of remote file: 6.29 MB

Git LFS Details

  • SHA256: 1f0ac037b9c16ad73612903c405a663dfb569ad2ad16b43cdfb1b001a5bbb661
  • Pointer size: 132 Bytes
  • Size of remote file: 6.43 MB
pipeline_sdxs.py CHANGED
@@ -248,7 +248,6 @@ class SdxsPipeline(DiffusionPipeline):
 
         # Encode -> Decode (using mean for deterministic upscale)
         latents = self.vae.encode(tensors).latent_dist.mean
-        latents = latents * self.vae_latents_std.to(latents) + self.vae_latents_mean.to(latents)
         decoded = self.vae.decode(latents.to(self.vae.dtype))[0]
 
         # 4. Post-process: Denormalize and Crop
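The dropped line applied the UNet-space de-normalization (`latents * std + mean`) inside a pure encode→decode round trip, presumably a double-scaling bug: the latents never get normalized in this path, so shifting them moved them out of the scale `vae.decode` expects. A minimal sketch of the resulting round trip, assuming `vae` is the pipeline's `AsymmetricAutoencoderKL` and `tensors` is an image batch normalized to [-1, 1]:

```python
import torch

@torch.no_grad()
def vae_roundtrip(vae, tensors):
    # Use the posterior mean rather than sampling, so the upscale is deterministic.
    latents = vae.encode(tensors).latent_dist.mean
    # Decode directly: encode/decode share the VAE's native latent scale,
    # so no extra mean/std shift is needed in between.
    return vae.decode(latents.to(vae.dtype))[0]
```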
samples/unet_384x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 40e4e80a08953df78d2cdf311b1c7f3956f170340c41c0c46ed47d6e5ad1d4c4
  • Pointer size: 131 Bytes
  • Size of remote file: 319 kB

Git LFS Details

  • SHA256: d00d4e2116748740011a782f02a0f3d76813c2bb5383a4bccbc17bc45d76d03a
  • Pointer size: 131 Bytes
  • Size of remote file: 321 kB
samples/unet_416x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 3b59750e748584c41e38a7f0051d0617503915fdeec9133a547925fdab72c29e
  • Pointer size: 131 Bytes
  • Size of remote file: 323 kB

Git LFS Details

  • SHA256: a764f89b9c7db78cf65380dc2b5efb9b80e66372d30d8b29f112f6ce1358c64c
  • Pointer size: 131 Bytes
  • Size of remote file: 289 kB
samples/unet_448x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 328b22afe1f69a970e186381c26eda3c3995650b8ab9c2bc67ae435647f386e2
  • Pointer size: 131 Bytes
  • Size of remote file: 356 kB

Git LFS Details

  • SHA256: 4bf3ad7d152835ef721f9b78d1b652560801c541ba3aa0ad5a5c0c17f539cef5
  • Pointer size: 131 Bytes
  • Size of remote file: 447 kB
samples/unet_480x704_0.jpg CHANGED

Git LFS Details

  • SHA256: ef1e7f79a85f40392c44fda93d3a8a381510828f09ba449c91631855a1c18541
  • Pointer size: 131 Bytes
  • Size of remote file: 513 kB

Git LFS Details

  • SHA256: 5412e778f18133d39004642863cf0725a3ef80616b12d3cc97caf202927b8934
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
samples/unet_512x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 40b4b933783788224a8f46f1190599185852a8543402647ec96573b4dbab66d2
  • Pointer size: 131 Bytes
  • Size of remote file: 377 kB

Git LFS Details

  • SHA256: 485050a25158286e44c8dbd55bb3f4d93feeaa29568c79fe386a5830db5098f3
  • Pointer size: 131 Bytes
  • Size of remote file: 512 kB
samples/unet_544x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 44fece279aeacbda8cec94ec352258cfee766bd2708d1831b5ff2d9c4b737063
  • Pointer size: 131 Bytes
  • Size of remote file: 645 kB

Git LFS Details

  • SHA256: 2262aa315669c814062bb5139da8dc1c798ee9bf0703d497e1b8dd10ecbe5e30
  • Pointer size: 131 Bytes
  • Size of remote file: 592 kB
samples/unet_576x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 478bfc400b7827bd784d743a7d7d41f81466be08f75080f74c886b31f8ba145e
  • Pointer size: 131 Bytes
  • Size of remote file: 553 kB

Git LFS Details

  • SHA256: 57a61d690cb6cfb7ee19d4edf1ad4ffde5ed745f555558f58b2067fcc13f09bd
  • Pointer size: 131 Bytes
  • Size of remote file: 469 kB
samples/unet_608x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 40105e4c18eac27d881bc315261d71bb9c06f4f91df1008e5ca98158802c9a92
  • Pointer size: 131 Bytes
  • Size of remote file: 389 kB

Git LFS Details

  • SHA256: 73a82467b9ea936641d28e26e21e5f62ec02e966a4a77d9349c629544ff237b0
  • Pointer size: 131 Bytes
  • Size of remote file: 462 kB
samples/unet_640x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 4450cdb9903b7010471ad848c4bc9eaa9940bb2c66625acd08ce4b80a6e4a20d
  • Pointer size: 131 Bytes
  • Size of remote file: 394 kB

Git LFS Details

  • SHA256: 913152c25eb7c1bfd93ff94b67b68ac8852672b079b2ca91b645c97e9f65aef1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
samples/unet_672x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 5d343f9758b8b6e09f862088b19db7f633f8435d19fc17cd80623ce99db046e3
  • Pointer size: 131 Bytes
  • Size of remote file: 259 kB

Git LFS Details

  • SHA256: 58ef1e9e22f16e69f55a5fee2c8fd6a647eac45eb635dfc9acd113afe39e190a
  • Pointer size: 131 Bytes
  • Size of remote file: 547 kB
samples/unet_704x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 064d3f2bc1023c7e81a197372dc1f49a60846165dd2ad3d9c57f8d36e120126f
  • Pointer size: 131 Bytes
  • Size of remote file: 345 kB

Git LFS Details

  • SHA256: 8bee9b3fb5479bced74ce18c08f97b36fe59776c60008fccd9939d8d4e2eb84c
  • Pointer size: 131 Bytes
  • Size of remote file: 295 kB
samples/unet_704x416_0.jpg CHANGED

Git LFS Details

  • SHA256: db55ffcd19721c73b94158cf808bd0ad37ef77aa2af25aff64316ad87eb2e6d1
  • Pointer size: 131 Bytes
  • Size of remote file: 329 kB

Git LFS Details

  • SHA256: a33349da23384d4670fd68c3a622370f2f64dbec7b06c9fb8cc1bb427331d8ee
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB
samples/unet_704x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 76fc0942a04be0b6e261aba92028d46bb344a2ab2fabe1e590992ec865657903
  • Pointer size: 131 Bytes
  • Size of remote file: 530 kB

Git LFS Details

  • SHA256: 358020e59b1af51c75e11c6c471ec1ec29a74d77a9dcf474279028f3d7ceaa9f
  • Pointer size: 131 Bytes
  • Size of remote file: 412 kB
samples/unet_704x480_0.jpg CHANGED

Git LFS Details

  • SHA256: 2e13d19a26a81263f4d11a453cdfb64436e6e6b3919edc7fedc4ee8f938e1e60
  • Pointer size: 131 Bytes
  • Size of remote file: 482 kB

Git LFS Details

  • SHA256: 689af1b587e2932ef7243594cb9f90b8d8fe3b7445dd7d38d19f4d358c703d47
  • Pointer size: 131 Bytes
  • Size of remote file: 328 kB
samples/unet_704x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 190a71005a21fe5be6ebcb8e0ba4e8353632394b615cc31180409ee0531a9b34
  • Pointer size: 131 Bytes
  • Size of remote file: 413 kB

Git LFS Details

  • SHA256: 87ee45512e26ffa68b8b1237f470fe641fa1a20f04c69eee1dd59731c58f1ebd
  • Pointer size: 131 Bytes
  • Size of remote file: 528 kB
samples/unet_704x544_0.jpg CHANGED

Git LFS Details

  • SHA256: b26efe72bb512c5df1550e22c84ebe2f5030df69749cc879a647a7af9c8a7250
  • Pointer size: 131 Bytes
  • Size of remote file: 178 kB

Git LFS Details

  • SHA256: cd1eff5df818b37eb365f9a1a39efc91396c0ffb2c3284bf7677215d323e4087
  • Pointer size: 131 Bytes
  • Size of remote file: 214 kB
samples/unet_704x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 5c16d69e85df706353a4ab2e64d27dcbff85864c8b5fb45108ea188a68fe8581
  • Pointer size: 131 Bytes
  • Size of remote file: 521 kB

Git LFS Details

  • SHA256: 3812742f1a6bab81ec0176b841006abbcc0bf29abedaa136b9773710c3268acc
  • Pointer size: 131 Bytes
  • Size of remote file: 535 kB
samples/unet_704x608_0.jpg CHANGED

Git LFS Details

  • SHA256: 55a89f5567e933fed12ee3fd0269f55c9265fa01219ed9119a3f1c2bd60d822b
  • Pointer size: 131 Bytes
  • Size of remote file: 386 kB

Git LFS Details

  • SHA256: b0a0c03bde7979077f4e152414e2dd3325d00776ce5549b60b5be783a3f26e18
  • Pointer size: 131 Bytes
  • Size of remote file: 425 kB
samples/unet_704x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 3e503d919fe7c90104ac4f9b66a073fe7883d642a73eab61830373d0afea3813
  • Pointer size: 131 Bytes
  • Size of remote file: 747 kB

Git LFS Details

  • SHA256: d8967830aeb09b45642c522ba92645cd518530a97a57ebc6b5b28075fbb900d5
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB
samples/unet_704x672_0.jpg CHANGED

Git LFS Details

  • SHA256: 6a751180451b9fd05bab2bb1d2815f6f2d05e6de1f7b27776e896458dcdee950
  • Pointer size: 131 Bytes
  • Size of remote file: 589 kB

Git LFS Details

  • SHA256: 3defeb654c8fd965d7565ad4e31e4f470adba86886fc40e5a5052b44bef1c94a
  • Pointer size: 131 Bytes
  • Size of remote file: 679 kB
samples/unet_704x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 22290dd9942a92ae67fa4249b351ddacd201e2fb4ec8fe9416e7eb132010a139
  • Pointer size: 131 Bytes
  • Size of remote file: 659 kB

Git LFS Details

  • SHA256: 22449772745d311d83f89a6c0798abb045c983b816b7813e4014405e4fb63ef0
  • Pointer size: 131 Bytes
  • Size of remote file: 716 kB
test.ipynb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1f99ca4f959a85a5453820eedad6b5c1fde29fcadd30a3eaae22097c96a0a7a0
-size 16433141
+oid sha256:bd0d1f7f30e064d8fb03320c204b7d67fd5c7df87da6919af818d6e2f60ee1a2
+size 12104053
train.py CHANGED
@@ -10,6 +10,7 @@ import bitsandbytes as bnb
 import torch.nn.functional as F
 import argparse
 
+from datetime import datetime
 from diffusers import UNet2DConditionModel, AsymmetricAutoencoderKL, FlowMatchEulerDiscreteScheduler
 from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
 from torch.utils.data import DataLoader, Sampler
@@ -21,7 +22,7 @@ from tqdm import tqdm
 from PIL import Image, ImageOps
 from torch.utils.checkpoint import checkpoint
 from diffusers.models.attention_processor import AttnProcessor2_0
-from datetime import datetime
+from contextlib import nullcontext
 
 # Muon not tested! pip install git+https://github.com/recoilme/muon_adamw8bit.git
 from muon_adamw8bit import MuonAdamW8bit
@@ -470,12 +471,14 @@ else:
         return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
                (1 + math.cos(math.pi * decay_ratio))
     lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
-    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
 
 if torch_compile:
-    print("compiling")
+    print("Compiling UNet... This will take a few minutes, do not interrupt!")
     unet = torch.compile(unet)
-    print("compiling - ok")
+    print("Compiling - ok")
+
+if not fbp:
+    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
 
 # Fixed samples
 fixed_samples = get_fixed_samples_by_resolution(dataset)
@@ -702,105 +705,108 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
     unet.train()
 
     for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
-        with accelerator.accumulate(unet):
-            if save_model == False and epoch == 0 and step == 5:
-                used_gb = torch.cuda.max_memory_allocated() / 1024**3
-                print(f"Step {step}: {used_gb:.2f} GB")
-
-            # noise
-            noise = torch.randn_like(latents, dtype=latents.dtype)
-
-            # 3. Time t (sampled as before, but with the edges squeezed slightly)
-            u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
-            t = u * (1 - 2 * 1e-5) + 1e-5  # now t is strictly in (0.00001 ... 0.99999)
-            # interpolation between x0 and noise
-            noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
-            # integer timesteps for the UNet
-            timesteps = t.to(torch.float32).mul(999.0)
-            timesteps = timesteps.clamp(0, scheduler.config.num_train_timesteps - 1)
-
-            # --- UNet call with attention mask ---
-            model_pred = unet(
-                noisy_latents,
-                timesteps,
-                encoder_hidden_states=embeddings,
-                encoder_attention_mask=attention_mask,
-            ).sample
-
-            target = noise - latents
-
-            mse_loss = F.mse_loss(model_pred.float(), target.float())
-            batch_losses.append(mse_loss.detach().item())
-
-            if (global_step % 100 == 0) or (global_step % sink_interval == 0):
-                accelerator.wait_for_everyone()
-
-            losses_dict = {}
-            losses_dict["mse"] = mse_loss
-
-            if (global_step % 100 == 0) or (global_step % sink_interval == 0):
-                accelerator.wait_for_everyone()
-
-            accelerator.backward(mse_loss)
-
-            if (global_step % 100 == 0) or (global_step % sink_interval == 0):
-                accelerator.wait_for_everyone()
-
-            grad = 0.0
-            if not fbp:
-                if accelerator.sync_gradients:
-                    grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
-                    grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
-                optimizer.step()
-                lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
-
-            if accelerator.sync_gradients:
-                global_step += 1
-                progress_bar.update(1)
-                if accelerator.is_main_process:
-                    if fbp:
-                        current_lr = base_learning_rate
-                    else:
-                        current_lr = lr_scheduler.get_last_lr()[0]
-                    batch_grads.append(grad)
-
-                    log_data = {}
-                    log_data["loss_mse"] = mse_loss.detach().item()
-                    log_data["lr"] = current_lr
-                    log_data["grad"] = grad
-                    if accelerator.sync_gradients:
-                        if use_wandb:
-                            wandb.log(log_data, step=global_step)
-                        if use_comet_ml:
-                            comet_experiment.log_metrics(log_data, step=global_step)
-
-                    current_time = time.time()
-                    is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
-                    if is_time_to_sample or global_step == 50:
-                        # pass a tuple (emb, mask) for the negative prompt
-                        if save_model:
-                            generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
-                        elif epoch % 10 == 0:
-                            generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
-                        last_n = sink_interval
-
-                        if save_model:
-                            has_losses = len(batch_losses) > 0
-                            avg_sample_loss = np.mean(batch_losses[-sink_interval:]) if has_losses else 0.0
-                            last_loss = batch_losses[-1] if has_losses else 0.0
-                            max_loss = max(avg_sample_loss, last_loss)
-                            should_save = max_loss < min_loss * save_barrier
-                            print(
-                                f"Saving: {should_save} | Max: {max_loss:.4f} | "
-                                f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
-                            )
-                            # 6. Save and update
-                            if should_save:
-                                min_loss = max_loss
-                                save_checkpoint(unet)
-                        last_sample_time = current_time
-                        unet.train()
+
+        if save_model == False and epoch == 0 and step == 5:
+            used_gb = torch.cuda.max_memory_allocated() / 1024**3
+            print(f"Step {step}: {used_gb:.2f} GB")
+
+        amp_context = accelerator.autocast() if torch_compile else nullcontext()
+        with accelerator.accumulate(unet):
+            with amp_context:
+                # noise
+                noise = torch.randn_like(latents, dtype=latents.dtype)
+
+                # 3. Time t (sampled as before, but with the edges squeezed slightly)
+                u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
+                t = u * (1 - 2 * 1e-5) + 1e-5  # now t is strictly in (0.00001 ... 0.99999)
+                # interpolation between x0 and noise
+                noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
+                # integer timesteps for the UNet
+                timesteps = t.to(torch.float32).mul(999.0)
+                timesteps = timesteps.clamp(0, scheduler.config.num_train_timesteps - 1)
+
+                # --- UNet call with attention mask ---
+                model_pred = unet(
+                    noisy_latents,
+                    timesteps,
+                    encoder_hidden_states=embeddings,
+                    encoder_attention_mask=attention_mask,
+                ).sample
+
+            target = noise - latents
+
+            mse_loss = F.mse_loss(model_pred.float(), target.float())
+            batch_losses.append(mse_loss.detach().item())
+
+            if (global_step % 100 == 0) or (global_step % sink_interval == 0):
+                accelerator.wait_for_everyone()
+
+            losses_dict = {}
+            losses_dict["mse"] = mse_loss
+
+            if (global_step % 100 == 0) or (global_step % sink_interval == 0):
+                accelerator.wait_for_everyone()
+
+            accelerator.backward(mse_loss)
+
+            if (global_step % 100 == 0) or (global_step % sink_interval == 0):
+                accelerator.wait_for_everyone()
+
+            grad = 0.0
+            if not fbp:
+                if accelerator.sync_gradients:
+                    grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
+                    grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=True)
+
+            if accelerator.sync_gradients:
+                global_step += 1
+                progress_bar.update(1)
+                if accelerator.is_main_process:
+                    if fbp:
+                        current_lr = base_learning_rate
+                    else:
+                        current_lr = lr_scheduler.get_last_lr()[0]
+                    batch_grads.append(grad)
+
+                    log_data = {}
+                    log_data["loss_mse"] = mse_loss.detach().item()
+                    log_data["lr"] = current_lr
+                    log_data["grad"] = grad
+                    if accelerator.sync_gradients:
+                        if use_wandb:
+                            wandb.log(log_data, step=global_step)
+                        if use_comet_ml:
+                            comet_experiment.log_metrics(log_data, step=global_step)
+
+                    current_time = time.time()
+                    is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
+                    if is_time_to_sample or global_step == 50:
+                        # pass a tuple (emb, mask) for the negative prompt
+                        if save_model:
+                            generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
+                        elif epoch % 10 == 0:
+                            generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
+                        last_n = sink_interval
+
+                        if save_model:
+                            has_losses = len(batch_losses) > 0
+                            avg_sample_loss = np.mean(batch_losses[-sink_interval:]) if has_losses else 0.0
+                            last_loss = batch_losses[-1] if has_losses else 0.0
+                            max_loss = max(avg_sample_loss, last_loss)
+                            should_save = max_loss < min_loss * save_barrier
+                            print(
+                                f"Saving: {should_save} | Max: {max_loss:.4f} | "
+                                f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
+                            )
+                            # 6. Save and update
+                            if should_save:
+                                min_loss = max_loss
+                                save_checkpoint(unet)
+                        last_sample_time = current_time
+                        unet.train()
 
     if accelerator.is_main_process:
         avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
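The core of this training-loop change is twofold: the one-off memory probe moves out of the accumulation block, and the forward pass is wrapped in an autocast context only when `torch_compile` is enabled, with `nullcontext()` as the no-op fallback. A condensed sketch of that pattern together with the flow-matching objective the loop trains (names mirror the script; the function wrapper is illustrative only):

```python
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator()
torch_compile = True  # assumption for the sketch: mirrors the script's flag


def flow_matching_step(unet, latents, embeddings, attention_mask, scheduler):
    # Autocast only around the (compiled) UNet forward; otherwise a no-op context.
    amp_context = accelerator.autocast() if torch_compile else nullcontext()
    with amp_context:
        noise = torch.randn_like(latents)
        # t in (1e-5, 1 - 1e-5): linear interpolation between data x0 and noise.
        u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
        t = u * (1 - 2 * 1e-5) + 1e-5
        noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
        # Map continuous t onto the scheduler's discrete timestep range.
        timesteps = t.float().mul(999.0).clamp(0, scheduler.config.num_train_timesteps - 1)
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=embeddings,
            encoder_attention_mask=attention_mask,
        ).sample
    # Rectified-flow velocity target: the model predicts noise - x0.
    return F.mse_loss(model_pred.float(), (noise - latents).float())
```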
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2003b45e59873c0430ebf4889ee94bae5eca7b7aafb898f378689697f5c30c69
+oid sha256:9b2e01643382012adcffc091502a64e908d85a394504a92cf8bb8aa8cec080b9
 size 3210307232