Text-to-Image
Diffusers
Safetensors
recoilme committed on
Commit
82fa13e
·
1 Parent(s): 2c56fee
girl.jpg CHANGED

Git LFS Details

  • SHA256: 84d10d43409b09a698dc38bb6a8e1dc14c1f5b1b55a9ee8399cce66fcf278dcc
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB

Git LFS Details

  • SHA256: 4db57bd9d6001ba9fc1ad4cdaff4d85ba07ea4df16ee9e0603b1b8a7b5e71f6c
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
media/result_grid.jpg CHANGED

Git LFS Details

  • SHA256: 6498c3a97b5351e8b43c3fb5e23ee5ced8705c9d47ccf62ec3b28ebe0064b735
  • Pointer size: 132 Bytes
  • Size of remote file: 2.73 MB

Git LFS Details

  • SHA256: ffb45b7b0f02aff4c0c3cc9cf46f93a58bdca2b58fd9e4a016e09d84d358f97e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.74 MB
requirements.txt CHANGED
@@ -8,4 +8,5 @@ bitsandbytes>=0.45.4
 transformers
 hf_transfer
 comet_ml
-flash-linear-attention
+flash-linear-attention
+git+https://github.com/recoilme/muon_adamw8bit.git
samples/unet_384x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 24e0881606e4b4ceddffc2281b77c21bc05d93bbd7bc075fede3acfa8c1f90d6
  • Pointer size: 131 Bytes
  • Size of remote file: 330 kB

Git LFS Details

  • SHA256: b5039acdd63d5b6377c834c0aa284e0ec58db8159b5173f3c5e43201ac2bef66
  • Pointer size: 131 Bytes
  • Size of remote file: 491 kB
samples/unet_416x704_0.jpg CHANGED

Git LFS Details

  • SHA256: d8b19dd8bcf5b7cf70ada28e28359b73141461676ba5d25d573882e6ce26f9a9
  • Pointer size: 131 Bytes
  • Size of remote file: 245 kB

Git LFS Details

  • SHA256: f7148ac9c52885d068c0f6ca2ecb2b8ccd5ad9939e53e0d4bd729ffa44701364
  • Pointer size: 131 Bytes
  • Size of remote file: 658 kB
samples/unet_448x704_0.jpg CHANGED

Git LFS Details

  • SHA256: c39bde7278a342e62fdf4d14ed59619830f6de5a5e152f33b11ed86f224a2608
  • Pointer size: 131 Bytes
  • Size of remote file: 556 kB

Git LFS Details

  • SHA256: 839b4b29ba6aa8fc67e6dda81a818aaf358d7d36ff368cc2cca71f21b6b90a76
  • Pointer size: 131 Bytes
  • Size of remote file: 485 kB
samples/unet_480x704_0.jpg CHANGED

Git LFS Details

  • SHA256: fc70fa558cddcec65c65d8829e37deeb4a9183ef2e2f88c6cba26f313790d9c5
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB

Git LFS Details

  • SHA256: 28ad038e6e613e6d8338caa24e270fb39605d536117df761d0283e874d965cb0
  • Pointer size: 131 Bytes
  • Size of remote file: 639 kB
samples/unet_512x704_0.jpg CHANGED

Git LFS Details

  • SHA256: bf012a689d7d9c649b42b710526af4c735c70a53eb522fb2dbede7290fd9b6bd
  • Pointer size: 131 Bytes
  • Size of remote file: 588 kB

Git LFS Details

  • SHA256: ac6d8798a64a1c3742bb661b14bd644cde39db862772586edf33dbc71d4427e3
  • Pointer size: 131 Bytes
  • Size of remote file: 421 kB
samples/unet_544x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 001ccceb63702edd7ba146afd8c021f03a8fdff1e66617667d9cfac4b3d52bf7
  • Pointer size: 131 Bytes
  • Size of remote file: 583 kB

Git LFS Details

  • SHA256: 8b3576913e6869fab3ab95b6bac7fd2216350ccbea6d9df3a6e60d329d1e5b41
  • Pointer size: 131 Bytes
  • Size of remote file: 833 kB
samples/unet_576x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 85689aed746b4f06959fee703ee74a62def5bcc63b1d0b8a3af1ac5da199ef25
  • Pointer size: 131 Bytes
  • Size of remote file: 317 kB

Git LFS Details

  • SHA256: 0b23d6be85fcb742bd1041aa396904c89149720b40346300669b030873617831
  • Pointer size: 131 Bytes
  • Size of remote file: 347 kB
samples/unet_608x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 96fa7fb3d1f5764a7b18f5ee0a2c43ed2c2c5ba7afc549918c473af1c28200f1
  • Pointer size: 131 Bytes
  • Size of remote file: 430 kB

Git LFS Details

  • SHA256: 7acbb8630de69598a11f204164a7c3fb5978b63770ea1c035b2d748314623688
  • Pointer size: 131 Bytes
  • Size of remote file: 718 kB
samples/unet_640x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 4a5b6951aa9b2b3925b11f2db841bb02319c64b512aafa9effbe650016fd426e
  • Pointer size: 131 Bytes
  • Size of remote file: 657 kB

Git LFS Details

  • SHA256: 38a73b53b4470407ddd3ab2cbb65209e70b0af83033532731e971d88d5aa0f54
  • Pointer size: 131 Bytes
  • Size of remote file: 429 kB
samples/unet_672x704_0.jpg CHANGED

Git LFS Details

  • SHA256: a9c30a3f21bd55021bf7a6d2b5c8ee73201421550cd59d0cf105aca15fad5256
  • Pointer size: 131 Bytes
  • Size of remote file: 320 kB

Git LFS Details

  • SHA256: c0455fe1a933b76145e12bdb5fd73572a1ab9d3bd77a2b5035b9e9b45f346cf2
  • Pointer size: 131 Bytes
  • Size of remote file: 578 kB
samples/unet_704x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 83ee491a926a8857d1bb206439e7b9ce45cc9861c1e873c92dd69f97af25a8b7
  • Pointer size: 131 Bytes
  • Size of remote file: 320 kB

Git LFS Details

  • SHA256: 758540dbc74a8b85a4ad4a3834af36c9abbb38a866813e6d8d538f218f67ecfe
  • Pointer size: 131 Bytes
  • Size of remote file: 295 kB
samples/unet_704x416_0.jpg CHANGED

Git LFS Details

  • SHA256: 1fcda2fc739884fe414d42c1ab766886fbfba8cb3d4fb45ceb47dc1cb427eea2
  • Pointer size: 131 Bytes
  • Size of remote file: 767 kB

Git LFS Details

  • SHA256: c5a079b679107103b9fd53d384dfdf6d032a658a0037f1745790ca012770cd7f
  • Pointer size: 131 Bytes
  • Size of remote file: 152 kB
samples/unet_704x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 5e215474d7f777c3b0535087882fc5cb459cfb778147a2f5e7480a070646b2a6
  • Pointer size: 131 Bytes
  • Size of remote file: 345 kB

Git LFS Details

  • SHA256: 915d66b574e1589cf4c181899df7b5af9becb73e9c560c80e9470f54dc6e3167
  • Pointer size: 131 Bytes
  • Size of remote file: 539 kB
samples/unet_704x480_0.jpg CHANGED

Git LFS Details

  • SHA256: 896bf9f723173f7dafabadbdb549123879009e8928cc588f40a497830adafb8d
  • Pointer size: 131 Bytes
  • Size of remote file: 712 kB

Git LFS Details

  • SHA256: fb0f8967405bc6b3223b5aa240e1ea55a6615810a43b5a100bb06eb32adbf2fa
  • Pointer size: 131 Bytes
  • Size of remote file: 544 kB
samples/unet_704x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 4ff6ad2a8a229b5b4ec144d4d365ec016f5671e0ee8626e2b39486e6484d3e9c
  • Pointer size: 131 Bytes
  • Size of remote file: 491 kB

Git LFS Details

  • SHA256: 4f9751905b573389085bde5f4f81a6c669420d176e550f35c8ea74c3e68d2736
  • Pointer size: 131 Bytes
  • Size of remote file: 485 kB
samples/unet_704x544_0.jpg CHANGED

Git LFS Details

  • SHA256: ffbbbf2d9838a9e6589803db96a247f045832cfcc4b72670f7df86fc4d094717
  • Pointer size: 131 Bytes
  • Size of remote file: 407 kB

Git LFS Details

  • SHA256: 1ba0ce4df0d3d4ed8baa539fb889ed411a8c4c382d06352420cd8519a0e5d93f
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
samples/unet_704x576_0.jpg CHANGED

Git LFS Details

  • SHA256: b87cd73815544d9175051990c44a971a74df58416bcc5fd43eb4f280148cee02
  • Pointer size: 131 Bytes
  • Size of remote file: 555 kB

Git LFS Details

  • SHA256: ad95cf8c01354a7f35e79b131c612d54a955fb7cea730565b1232a1c33f4ce1c
  • Pointer size: 131 Bytes
  • Size of remote file: 411 kB
samples/unet_704x608_0.jpg CHANGED

Git LFS Details

  • SHA256: cefc44fd65e287a995bcddfef5b464e5ebd01e9dbb99328e03401c0af3954583
  • Pointer size: 131 Bytes
  • Size of remote file: 296 kB

Git LFS Details

  • SHA256: 5cf77a9e396e1ebef5e534af0329ff24048800e990da295310c71920a816d68c
  • Pointer size: 131 Bytes
  • Size of remote file: 569 kB
samples/unet_704x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 8bad60111dfd92dbbf97e639117d684c4522be34951a50ca5a874e97ae1a2095
  • Pointer size: 131 Bytes
  • Size of remote file: 876 kB

Git LFS Details

  • SHA256: f611599f10ced8cd9734aeb4df1f45fade63510ef782245a72e5ddd138052533
  • Pointer size: 131 Bytes
  • Size of remote file: 600 kB
samples/unet_704x672_0.jpg CHANGED

Git LFS Details

  • SHA256: 241219b0f6643da59ae96996609c8139183778bd36c84bac61d09acf464e6b23
  • Pointer size: 131 Bytes
  • Size of remote file: 657 kB

Git LFS Details

  • SHA256: a8cf25277f00a26cdb40c1a1308e4a98b49fe54d76fb81ec8eee18c79f5c421b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.31 MB
samples/unet_704x704_0.jpg CHANGED

Git LFS Details

  • SHA256: d31998797e23f94c5eef4690693768ec7cbbd60f9a5e282934cc3210a76e6d74
  • Pointer size: 131 Bytes
  • Size of remote file: 932 kB

Git LFS Details

  • SHA256: 207e8a5c8d33b0bfede10c69ab9ddc4f38281fc0241c225e660933fa8efc798b
  • Pointer size: 131 Bytes
  • Size of remote file: 965 kB
test.ipynb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2780ff8a377ed48e6fac594ae91ca31d32127771d2662ff5103bdfd42c76fa8e
-size 6429453
+oid sha256:c442f0dceac55de4922225b7cfa939dd38d2e3b4abf2894a45b2e7e098e927fc
+size 6417106
train-Copy1.py ADDED
@@ -0,0 +1,1005 @@
import os
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"  # comment this out on H100
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import LambdaLR
from collections import defaultdict
from diffusers import UNet2DConditionModel, AsymmetricAutoencoderKL, FlowMatchEulerDiscreteScheduler
from accelerate import Accelerator, DeepSpeedPlugin
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image, ImageOps
import wandb
import random, time
import gc
from accelerate.state import DistributedType
from torch.distributed import broadcast_object_list
from torch.utils.checkpoint import checkpoint
from diffusers.models.attention_processor import AttnProcessor2_0
from datetime import datetime
import bitsandbytes as bnb
import torch.nn.functional as F
from collections import deque
from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
import argparse
# pip install git+https://github.com/recoilme/muon_adamw8bit.git
from muon_adamw8bit import MuonAdamW8bit

# --------------------------- Parameters ---------------------------
ds_path = "datasets/ds1234_noanime_704_vae8x16x"
project = "unet"

# 1. Compute the per-GPU (local) batch size
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
local_bs = max(1, int((gpu_mem_gb / 32) * 8))
# 2. Multiply by the GPU count to get the GLOBAL batch size
num_gpus = torch.cuda.device_count()
batch_size = local_bs * num_gpus
print(f"GPUs: {num_gpus}, Local BS: {local_bs}, Global BS: {local_bs * num_gpus}")
base_learning_rate = 1e-5
min_learning_rate = 1e-7
num_epochs = num_gpus
sample_interval_share = 20
cfg_dropout = 0.10
max_length = 248
use_precomputed_embeddings = False
use_wandb = True
use_comet_ml = False
save_model = True
use_decay = True
fbp = False
torch_compile = False
unet_gradient = True
loss_normalize = False
fixed_seed = False
shuffle = True
muon_lr_scale = 300
optimizer_type = "muon_adam8bit"
if optimizer_type == "muon_adam8bit":
    batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 3))
comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
comet_ml_workspace = "recoilme"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Enable Flash Attention 2/SDPA  # MAX_JOBS=4 pip install flash-attn --no-build-isolation
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)  # disable the slow fallback
save_barrier = 1.25
warmup_percent = 0.01
betta2 = 0.995
eps = 1e-7
clip_grad_norm = 1.0
limit = 0
checkpoints_folder = ""
gradient_accumulation_steps = 1
dtype = torch.float32
mixed_precision = "no"

# Diffusion parameters
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 4

# Output folders
generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)

# Seed setup
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d")) + 42
if fixed_seed:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

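The per-GPU batch size above comes from a simple linear memory heuristic: 8 samples per 32 GB of VRAM (3 per 32 GB for the Muon path), floored at 1. A standalone sketch of that rule (the function name `local_batch_size` is ours, not the script's):

```python
def local_batch_size(gpu_mem_gb: float, per_32gb: int = 8) -> int:
    """Linear heuristic from the script: `per_32gb` samples per 32 GB of VRAM, at least 1."""
    return max(1, int((gpu_mem_gb / 32) * per_32gb))

# Approximate memory sizes of common cards:
print(local_batch_size(24))   # 24 GB card -> 6
print(local_batch_size(32))   # 32 GB card -> 8
print(local_batch_size(80))   # 80 GB card -> 20
```

With `per_32gb=3` (the `muon_adam8bit` branch), an 80 GB card gets 7 samples per GPU.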
print("init")
# Build the ArgumentParser with the computed defaults
parser = argparse.ArgumentParser(description='Train a model on a dataset.')
parser.add_argument('--ds-path', type=str, default=ds_path, help='Path to the dataset')
parser.add_argument('--ep', type=int, default=num_epochs, help='Number of epochs to train the model')
parser.add_argument('--batch', type=int, default=batch_size, help='Total batch size')
parser.add_argument('--min-lr', type=float, default=min_learning_rate, help='Minimum learning rate')
parser.add_argument('--max-lr', type=float, default=base_learning_rate, help='Maximum learning rate')
parser.add_argument('--dry-run', action='store_true', default=False, help='Run configuration without saving/sampling')

# Parse the command-line arguments
args = parser.parse_args()

# Use the parsed values
batch_size = args.batch
ds_path = args.ds_path
base_learning_rate = args.max_lr
min_learning_rate = args.min_lr
num_epochs = args.ep
if args.dry_run:
    save_model = False

# --------------------------- WandB initialization ---------------------------
if accelerator.is_main_process:
    if use_wandb:
        wandb.init(project=project, config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        })
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace
        )
        hyper_params = {
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
        }
        comet_experiment.log_parameters(hyper_params)

# --------------------------- Model loading ---------------------------
vae = AsymmetricAutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
tokenizer = None
text_encoder = None

def load_text_encoder():
    global tokenizer, text_encoder
    if tokenizer is None:
        tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer")
    if text_encoder is None:
        text_encoder = Qwen3_5ForConditionalGeneration.from_pretrained(
            "text_encoder",
            torch_dtype=torch.float16
        ).to(device).eval()

def encode_texts(texts, max_length=max_length):
    load_text_encoder()
    if texts is None:
        texts = [""]
    if isinstance(texts, str):
        texts = [texts]

    with torch.no_grad():
        # --- QWEN encoder (via the chat template) ---
        # 1. Wrap each prompt in the chat template
        formatted_prompts = []
        for t in texts:
            messages = [{"role": "user", "content": [{"type": "text", "text": t}]}]
            res_text = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False
            )
            formatted_prompts.append(res_text)

        # 2. Tokenize, truncate and pad in a single pass
        toks = tokenizer(
            formatted_prompts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        # 3. Run the encoder
        outputs = text_encoder(
            input_ids=toks.input_ids,
            attention_mask=toks.attention_mask,
            output_hidden_states=True
        )

        layer_index = -2
        last_hidden = outputs.hidden_states[layer_index]
        seq_len = toks.attention_mask.sum(dim=1) - 1
        pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]
        # --- Combine token states and the pooled vector for cross-attention ---
        # 1. Expand the pooled vector to a length-1 sequence: [B, 1, 1024]
        pooled_expanded = pooled.unsqueeze(1)

        # 2. Concatenate, with the pooled vector FIRST:
        #    result is [B, 1 + L, 1024], the pooled state becomes a token at the START.
        new_encoder_hidden_states = torch.cat([pooled_expanded, last_hidden], dim=1)

        # 3. Extend the attention mask for the new token: [B, 1 + L],
        #    prepending a column of ones.
        new_attention_mask = torch.cat([torch.ones((last_hidden.shape[0], 1), device=device), toks.attention_mask], dim=1)
        return new_encoder_hidden_states.to(dtype), new_attention_mask

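The pooled-token trick in `encode_texts` is just a concatenation along the sequence axis plus a matching mask extension. The shape mechanics can be sketched with NumPy (all array names and the toy sizes are illustrative):

```python
import numpy as np

B, L, D = 2, 5, 8                      # batch, sequence length, hidden size
last_hidden = np.zeros((B, L, D))      # per-token hidden states from the encoder
pooled = np.ones((B, D))               # pooled vector (state of the last real token)
attention_mask = np.ones((B, L))

# Prepend the pooled vector as token 0: [B, 1 + L, D]
hidden = np.concatenate([pooled[:, None, :], last_hidden], axis=1)
# Extend the mask with a 1 for the new token: [B, 1 + L]
mask = np.concatenate([np.ones((B, 1)), attention_mask], axis=1)

print(hidden.shape, mask.shape)  # (2, 6, 8) (2, 6)
```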
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None:
    shift_factor = 0.0

scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None:
    scaling_factor = 1.0

mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)
if mean is not None and std is not None:
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)


class DistributedResolutionBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, num_replicas, rank, drop_last=True, shuffle=True):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        # per-GPU batch size
        self.batch_size = max(1, batch_size // num_replicas)
        self.global_batch = self.batch_size * num_replicas

        try:
            widths = np.asarray(dataset["width"])
            heights = np.asarray(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        # --- group indices by resolution ---
        groups = {}
        for i, (w, h) in enumerate(zip(widths, heights)):
            groups.setdefault((w, h), []).append(i)

        # --- build the list of all global batches ---
        all_batches = []
        for indices in groups.values():
            idx = np.asarray(indices, dtype=np.int64)

            num_batches = len(idx) // self.global_batch
            if num_batches == 0:
                continue

            idx = idx[: num_batches * self.global_batch]
            batches = idx.reshape(num_batches, self.global_batch)
            all_batches.append(batches)

        if len(all_batches) > 0:
            self.global_batches = np.concatenate(all_batches, axis=0)
        else:
            self.global_batches = np.empty((0, self.global_batch), dtype=np.int64)

        self.num_batches = len(self.global_batches)

    def __iter__(self):
        rng = np.random.RandomState(self.epoch)
        order = np.arange(self.num_batches)
        if self.shuffle:
            rng.shuffle(order)

        start = self.rank * self.batch_size
        end = start + self.batch_size

        for i in order:
            yield self.global_batches[i][start:end]

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch

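The core idea of `DistributedResolutionBatchSampler` — bucket sample indices by resolution, cut each bucket into fixed-size global batches (dropping remainders), then give each rank a contiguous slice of every batch — can be shown without torch. A minimal sketch (function and variable names are ours):

```python
from collections import defaultdict

def make_global_batches(sizes, global_batch):
    """Group sample indices by (w, h) and cut full global batches per group."""
    groups = defaultdict(list)
    for i, wh in enumerate(sizes):
        groups[wh].append(i)
    batches = []
    for idx in groups.values():
        for k in range(len(idx) // global_batch):          # remainder is dropped
            batches.append(idx[k * global_batch:(k + 1) * global_batch])
    return batches

# 5 images at 512x512, 4 at 704x384; global batch of 4, 2 ranks x 2 samples
sizes = [(512, 512)] * 5 + [(704, 384)] * 4
batches = make_global_batches(sizes, global_batch=4)
rank, per_rank = 1, 2
shard = [b[rank * per_rank:(rank + 1) * per_rank] for b in batches]
print(batches)  # [[0, 1, 2, 3], [5, 6, 7, 8]]
print(shard)    # [[2, 3], [7, 8]]
```

Every emitted batch contains a single resolution, so no padding or cropping is needed inside a batch.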
# --- [UPDATED] Fixed-sample selection ---
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, (w, h) in enumerate(zip(widths, heights)):
        size = (w, h)
        size_groups[size].append(i)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(samples_per_group, len(indices))
        if len(size_groups) == 1:
            n_samples = samples_to_generate
        if n_samples == 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]

        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        texts = [item["text"] for item in samples_data]

        # Encode the texts on the fly to get the masks and pooled vector
        if use_precomputed_embeddings:
            embeddings = torch.tensor(
                np.array([item["embeddings"] for item in samples_data]),
                device=device,
                dtype=dtype
            )
            masks = torch.tensor(
                np.array([item["attention_mask"] for item in samples_data]),
                device=device,
                dtype=torch.int64
            )
        else:
            embeddings, masks = encode_texts(texts)

        fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Created {len(fixed_samples)} fixed-sample groups by resolution")
    return fixed_samples

if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)


print(f"images: {len(dataset)}")

def collate_fn_simple(batch):
    latents = torch.from_numpy(
        np.array([item["vae"] for item in batch], dtype=np.float16)
    ).to(device, dtype=dtype)

    if use_precomputed_embeddings:
        embeddings = torch.from_numpy(
            np.array([item["embeddings"] for item in batch], dtype=np.float16)
        ).to(device, dtype=dtype)

        attention_mask = torch.from_numpy(
            np.array([item["attention_mask"] for item in batch], dtype=np.int64)
        ).to(device)

        return latents, embeddings, attention_mask

    raw_texts = [item["text"] for item in batch]

    texts = [
        "" if t.lower().startswith("zero")
        else "" if random.random() < cfg_dropout
        else t[1:].lstrip() if t.startswith(".")
        else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ", "").strip()
        for t in raw_texts
    ]

    embeddings, attention_mask = encode_texts(texts)
    attention_mask = attention_mask.to(dtype=torch.int64)

    return latents, embeddings, attention_mask

batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle
)

dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)

if accelerator.is_main_process:
    print("Total samples", len(dataloader))
dataloader = accelerator.prepare(dataloader)

start_epoch = 0
global_step = 0
total_training_steps = (len(dataloader) * num_epochs)
world_size = accelerator.state.num_processes

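The caption preprocessing inside `collate_fn_simple` combines three rules: captions starting with "zero" become empty, a random `cfg_dropout` fraction is blanked for classifier-free guidance, and a leading "." or known boilerplate prefixes are stripped. A deterministic sketch with the dropout disabled (the function name `clean_caption` is ours):

```python
import random

def clean_caption(t, cfg_dropout=0.0, rng=random):
    if t.lower().startswith("zero"):
        return ""                      # explicit unconditional caption
    if rng.random() < cfg_dropout:
        return ""                      # random CFG dropout
    if t.startswith("."):
        return t[1:].lstrip()          # "." marker: keep the text verbatim
    return (t.replace("The image shows ", "")
             .replace("The image is ", "")
             .replace("This image captures ", "")
             .strip())

print(clean_caption("zero caption"))               # ""
print(clean_caption(". raw prompt"))               # "raw prompt"
print(clean_caption("The image shows a red fox"))  # "a red fox"
```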
# Load the UNet
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Loading UNet from checkpoint:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
    if unet_gradient:
        unet.enable_gradient_checkpointing()
    unet.set_use_memory_efficient_attention_xformers(False)
    try:
        unet.set_attn_processor(AttnProcessor2_0())
    except Exception as e:
        print(f"Failed to enable SDPA: {e}")
        unet.set_use_memory_efficient_attention_xformers(True)
else:
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")

# --------------------------- Muon Implementation (FIXED for UNet Conv2d) ---------------------------
import torch.distributed as dist

def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    """
    Optimized version: supports BF16 and batching.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)

    # Use bfloat16 for speed (if the GPU supports it, otherwise float16/32).
    # Important: the dtype cast happens here.
    if G.dtype != torch.bfloat16 and G.dtype != torch.float32:
        # If float16 comes in, compute in float32 or bfloat16 for accuracy
        X = G.float()
    else:
        X = G

    # Transpose logic (when M > N)
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT  # .mT is a view, no memory copy

    # Normalize over the last two axes (correct for any number of dims)
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)

    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X

    return X.mT if transposed else X

# (this local class shadows the MuonAdamW8bit imported above)
class MuonAdamW8bit(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01,
                 muon_lr=None, muon_momentum=0.95, ns_steps=5):
        params = list(params)

        matrix_params = []
        scalar_params = []

        for p in params:
            if p.ndim >= 2:
                matrix_params.append(p)
            else:
                scalar_params.append(p)

        # --- IMPORTANT LR FIX ---
        # If the base lr is 1e-5, Muon's lr should be around 0.01-0.02:
        # scale the base lr by ~1000-2000x, or set it explicitly.
        # The Muon author uses 0.02; 0.01 is a safe minimum to start with.
        actual_muon_lr = muon_lr if muon_lr is not None else lr * muon_lr_scale
        self.muon_lr_scale = actual_muon_lr / lr if lr > 0 else 1.0
        self.ns_steps = ns_steps

        # Initialize Muon with the matrix parameters only
        # (note: pass the actual Muon lr)
        self.muon_opt = MuonInternal(matrix_params, lr=actual_muon_lr, momentum=muon_momentum, ns_steps=ns_steps)

        # Adam for the scalars (biases, norms)
        self.adam_opt = bnb.optim.AdamW8bit(scalar_params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        self.defaults = dict(lr=lr)
        self.state = {}
        self.param_groups = self.adam_opt.param_groups

    def zero_grad(self, set_to_none=False):
        self.muon_opt.zero_grad(set_to_none)
        self.adam_opt.zero_grad(set_to_none)

    @torch.no_grad()
    def step(self, closure=None):
        # Sync LR (in case a scheduler changed the lr in the Adam groups)
        current_base_lr = self.adam_opt.param_groups[0]['lr']
        for group in self.muon_opt.param_groups:
            group['lr'] = current_base_lr * self.muon_lr_scale

        self.muon_opt.step(closure)
        self.adam_opt.step(closure)

    # state_dict/load_state_dict kept as before
    def state_dict(self):
        return {
            'muon': self.muon_opt.state_dict(),
            'adam': self.adam_opt.state_dict()
        }

    def load_state_dict(self, state_dict):
        self.muon_opt.load_state_dict(state_dict['muon'])
        self.adam_opt.load_state_dict(state_dict['adam'])


# Internal Muon class (matrix update logic)
class MuonInternal(torch.optim.Optimizer):
    def __init__(self, params, lr=0.01, momentum=0.95, ns_steps=5):
        super().__init__(params, dict(lr=lr, momentum=momentum, ns_steps=ns_steps))
        self.distributed = torch.distributed.is_initialized()

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            ns_steps = group["ns_steps"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                g = p.grad
                state = self.state[p]

                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g, dtype=torch.float32)

                buf = state["momentum_buffer"]

                # Momentum update
                buf.mul_(momentum).add_(g.float())

                # Nesterov
                update = g.float().add(buf, alpha=momentum)

                # Handle Conv2d (flatten to 2D)
                original_shape = g.shape
                if g.ndim == 4:
                    update = update.view(update.size(0), -1)

                # Orthogonalization (takes the update and the step count)
                update = zeropower_via_newtonschulz5(update, steps=ns_steps)

                # Scaling (standard for Muon)
                update *= max(1, update.size(-2) / update.size(-1)) ** 0.5

                # Restore the Conv shape
                if g.ndim == 4:
                    update = update.view(original_shape)

                p.add_(update, alpha=-lr)
        return loss

# --------------------------- End Muon Implementation ---------------------------

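The quintic coefficients in `zeropower_via_newtonschulz5` are tuned for speed and only approximately orthogonalize the update. The property the optimizer relies on — driving a matrix toward the nearest semi-orthogonal one, X Xᵀ ≈ I, without an SVD — is easiest to verify with the classical cubic Newton–Schulz iteration, shown here in NumPy as a reference sketch (not the script's tuned variant):

```python
import numpy as np

def newton_schulz_cubic(G, steps=30, eps=1e-7):
    """Classical cubic Newton-Schulz: X -> 1.5*X - 0.5*X @ X.T @ X.
    Converges to the semi-orthogonal polar factor when all singular
    values are below sqrt(3); Frobenius normalization guarantees that."""
    X = G / (np.linalg.norm(G) + eps)
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X.T if transposed else X

rng = np.random.default_rng(0)
G = rng.standard_normal((4, 8))
X = newton_schulz_cubic(G)
print(np.allclose(X @ X.T, np.eye(4), atol=1e-6))  # True
```

Muon's quintic trades this exact convergence for far fewer iterations (5 steps instead of ~30), which is why its singular values land in a band around 1 rather than exactly at 1.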
589
+ def create_optimizer(name, params):
590
+ if name == "adam8bit":
591
+ return bnb.optim.AdamW8bit(
592
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01
593
+ )
594
+ elif name == "adam":
595
+ return torch.optim.AdamW(
596
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01
597
+ )
598
+ elif name == "muon_adam8bit":
599
+ return MuonAdamW8bit(
600
+ params,
601
+ lr=base_learning_rate,
602
+ betas=(0.9, betta2),
603
+ eps=eps,
604
+ weight_decay=0.01,
605
+ muon_momentum=0.95,
606
+ ns_dtype=torch.bfloat16,
607
+ muon_lr_mult=1000.0,
608
+ )
609
+ else:
610
+ raise ValueError(f"Unknown optimizer: {name}")
611
+
612
+ if fbp:
613
+ trainable_params = list(unet.parameters())
614
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
615
+ def optimizer_hook(param):
616
+ optimizer_dict[param].step()
617
+ optimizer_dict[param].zero_grad(set_to_none=True)
618
+ for param in trainable_params:
619
+ param.register_post_accumulate_grad_hook(optimizer_hook)
620
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
621
+ else:
622
+ unet.requires_grad_(True)
623
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
624
+ # 1. First, freeze ALL UNet parameters
625
+ #unet.requires_grad_(False)
626
+
627
+ # 2. Unfreeze only the required layers
628
+ #trainable_params_names = ["conv_in.weight", "conv_in.bias", "conv_out.weight", "conv_out.bias"]
629
+ #train_params = []
630
+
631
+ #for name, param in unet.named_parameters():
632
+ # if any(target in name for target in trainable_params_names):
633
+ # param.requires_grad = True
634
+ # train_params.append(param)
635
+ # print(f"Trainable layer: {name}")
636
+
637
+ def lr_schedule(step):
638
+ x = step / (total_training_steps * world_size)
639
+ warmup = warmup_percent
640
+ if not use_decay:
641
+ return base_learning_rate
642
+ if x < warmup:
643
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
644
+ decay_ratio = (x - warmup) / (1 - warmup)
645
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
646
+ (1 + math.cos(math.pi * decay_ratio))
647
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
648
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
649
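The warmup-plus-cosine schedule wired into `LambdaLR` above can be sketched standalone. The defaults below mirror the script's `base_learning_rate` / `min_learning_rate`; `warmup_percent=0.05` and `total_steps` are assumed placeholders for the script's globals:

```python
import math

def lr_schedule(step, total_steps, base_lr=4e-5, min_lr=4e-7,
                warmup_percent=0.05, use_decay=True):
    # Linear warmup from min_lr to base_lr over the first warmup_percent of
    # training, then cosine decay back down to min_lr (as in train.py, where
    # total_steps stands in for total_training_steps * world_size).
    x = step / total_steps
    if not use_decay:
        return base_lr
    if x < warmup_percent:
        return min_lr + (base_lr - min_lr) * (x / warmup_percent)
    decay_ratio = (x - warmup_percent) / (1 - warmup_percent)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
```

As in the script, the value is divided by `base_lr` before being handed to `LambdaLR`, since `LambdaLR` expects a multiplicative factor rather than an absolute learning rate.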
+
650
+ if torch_compile:
651
+ print("compiling")
652
+ unet = torch.compile(unet)
653
+ print("compiling - ok")
654
+
655
+ # Fixed samples
656
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
657
+
658
+ # --- [UPDATED] Negative-embedding helper (returns the embedding and its mask) ---
659
+ def get_negative_embedding(neg_prompt="", batch_size=1):
660
+ if not neg_prompt:
661
+ hidden_dim = 2048
662
+ seq_len = max_length
663
+ empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
664
+ empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
665
+ return empty_emb, empty_mask
666
+
667
+ uncond_emb, uncond_mask = encode_texts([neg_prompt])
668
+ uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
669
+ uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
670
+
671
+ return uncond_emb, uncond_mask
672
+
673
+ # Get the negative (empty) conditioning for validation
674
+ if use_precomputed_embeddings:
675
+ # 1. Load the text encoder TEMPORARILY
676
+ load_text_encoder()
677
+
678
+ # 2. Compute the negative embedding
679
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
680
+
681
+ # 3. Move it to CPU (very important)
682
+ uncond_emb = uncond_emb.to("cpu")
683
+ uncond_mask = uncond_mask.to("cpu")
684
+
685
+ # 4. Unload the encoder from the GPU
686
+ del text_encoder
687
+ torch.cuda.empty_cache()
688
+ gc.collect()
689
+
690
+ text_encoder = None
691
+
692
+ else:
693
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
694
+
695
+ # --- Sample-generation helper ---
696
+ @torch.compiler.disable()
697
+ @torch.no_grad()
698
+ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
699
+ uncond_emb, uncond_mask = uncond_data
700
+ uncond_emb = uncond_emb.to(device)
701
+ uncond_mask = uncond_mask.to(device)
702
+
703
+ original_model = None
704
+ try:
705
+ if not torch_compile:
706
+ original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
707
+ else:
708
+ original_model = unet.eval()
709
+
710
+ vae.to(device=device).eval()
711
+
712
+ all_generated_images = []
713
+ all_captions = []
714
+
715
+ # Unpack four elements per resolution (masks were added)
716
+ for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
717
+ width, height = size
718
+ sample_latents = sample_latents.to(dtype=dtype, device=device)
719
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
720
+ sample_mask = sample_mask.to(device=device)
721
+
722
+ latents = torch.randn(
723
+ sample_latents.shape,
724
+ device=device,
725
+ dtype=sample_latents.dtype,
726
+ generator=torch.Generator(device=device).manual_seed(seed)
727
+ )
728
+
729
+ scheduler.set_timesteps(n_diffusion_steps, device=device)
730
+
731
+ for t in scheduler.timesteps:
732
+ if guidance_scale != 1:
733
+ latent_model_input = torch.cat([latents, latents], dim=0)
734
+
735
+ curr_batch_size = sample_text_embeddings.shape[0]
736
+ seq_len = sample_text_embeddings.shape[1]
737
+ hidden_dim = sample_text_embeddings.shape[2]
738
+
739
+ neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
740
+ text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)
741
+
742
+ neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
743
+ attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
744
+
745
+ else:
746
+ latent_model_input = latents
747
+ text_embeddings_batch = sample_text_embeddings
748
+ attention_mask_batch = sample_mask
749
+
750
+ # Now everything has the same batch size
751
+ model_out = original_model(
752
+ latent_model_input,
753
+ t,
754
+ encoder_hidden_states=text_embeddings_batch,
755
+ encoder_attention_mask=attention_mask_batch,
756
+ )
757
+
758
+ flow = getattr(model_out, "sample", model_out)
759
+
760
+ if guidance_scale != 1:
761
+ flow_uncond, flow_cond = flow.chunk(2)
762
+ flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
763
+
764
+ latents = scheduler.step(flow, t, latents).prev_sample
765
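The guidance combine inside the denoising loop above is standard classifier-free guidance. A minimal NumPy sketch of just that arithmetic (the function name is illustrative):

```python
import numpy as np

def apply_cfg(flow_uncond, flow_cond, guidance_scale):
    # Classifier-free guidance: start from the unconditional prediction and
    # extrapolate toward the conditional one. A scale of 1.0 reduces to the
    # conditional prediction (which is why the loop skips the doubled batch
    # when guidance_scale == 1); larger scales push harder toward the prompt.
    return flow_uncond + guidance_scale * (flow_cond - flow_uncond)
```

This is why the sampling loop concatenates `[latents, latents]` and `[negative, positive]` conditioning: one UNet forward pass produces both predictions, which are then split with `chunk(2)` and recombined here.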
+
766
+ current_latents = latents
767
+ if step == 0:
768
+ current_latents = sample_latents
769
+
770
+ if latents_mean is not None and latents_std is not None:
771
+ current_latents = current_latents * latents_std + latents_mean
772
+
773
+ decoded = vae.decode(current_latents.to(torch.float32)).sample
774
+ decoded_fp32 = decoded.to(torch.float32)
775
+
776
+ for img_idx, img_tensor in enumerate(decoded_fp32):
777
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
778
+ img = img.transpose(1, 2, 0)
779
+
780
+ if np.isnan(img).any():
781
+ print("NaNs found, skipping this sample! Step:", step)
+ continue
782
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
783
+
784
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
785
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
786
+ max_w_overall = max(255, max_w_overall)
787
+ max_h_overall = max(255, max_h_overall)
788
+
789
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
790
+ all_generated_images.append(padded_img)
791
+
792
+ caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
793
+ all_captions.append(caption_text)
794
+
795
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
796
+ pil_img.save(sample_path, "JPEG", quality=95)
797
+
798
+ if use_wandb and accelerator.is_main_process:
799
+ wandb_images = [
800
+ wandb.Image(img, caption=f"{all_captions[i]}")
801
+ for i, img in enumerate(all_generated_images)
802
+ ]
803
+ wandb.log({"generated_images": wandb_images})
804
+ if use_comet_ml and accelerator.is_main_process:
805
+ for i, img in enumerate(all_generated_images):
806
+ comet_experiment.log_image(
807
+ image_data=img,
808
+ name=f"step_{step}_img_{i}",
809
+ step=step,
810
+ metadata={"caption": all_captions[i]}
811
+ )
812
+ finally:
813
+ vae.to("cpu")
814
+ uncond_emb = uncond_emb.to("cpu")
815
+ uncond_mask = uncond_mask.to("cpu")
816
+ try:
817
+ all_generated_images.clear()
818
+ all_captions.clear()
819
+ del all_generated_images, all_captions
820
+ del latents, current_latents, latent_model_input, flow
821
+ del decoded, decoded_fp32
822
+ del sample_latents, sample_text_embeddings, sample_mask # GPU copies
823
+ del model_out
824
+ except UnboundLocalError:
825
+ pass
826
+
827
+ # 3. Synchronize CUDA before cleanup
828
+ torch.cuda.synchronize()
829
+ # 4. Now clear the allocator cache and run the GC
830
+ torch.cuda.empty_cache()
831
+ gc.collect()
832
+
833
+ # --------------------------- Sample generation before training ---------------------------
834
+ if accelerator.is_main_process:
835
+ if save_model:
836
+ print("Generating samples before training starts...")
837
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
838
+ accelerator.wait_for_everyone()
839
+
840
+ def save_checkpoint(unet, variant=""):
841
+ if accelerator.is_main_process:
842
+ model_to_save = None
843
+ if not torch_compile:
844
+ model_to_save = accelerator.unwrap_model(unet)
845
+ else:
846
+ model_to_save = unet
847
+
848
+ if variant != "":
849
+ model_to_save.to(dtype=torch.float16).save_pretrained(
850
+ os.path.join(checkpoints_folder, f"{project}"), variant=variant
851
+ )
852
+ else:
853
+ model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
854
+
855
+ torch.cuda.synchronize()
856
+ torch.cuda.empty_cache()
857
+ gc.collect()
858
+
859
+ # --------------------------- Training loop ---------------------------
860
+ if accelerator.is_main_process:
861
+ print(f"Total steps per GPU: {total_training_steps}")
862
+
863
+ epoch_loss_points = []
864
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
865
+
866
+ steps_per_epoch = len(dataloader)
867
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
868
+ min_loss = 4.
869
+ last_sample_time = time.time()
870
+ sample_interval_seconds = 60 * 60 # 60 minutes
871
+
872
+ for epoch in range(start_epoch, start_epoch + num_epochs):
873
+ batch_losses = []
874
+ batch_grads = []
875
+ batch_sampler.set_epoch(epoch)
876
+ accelerator.wait_for_everyone()
877
+ unet.train()
878
+
879
+ for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
880
+ with accelerator.accumulate(unet):
881
+ if not save_model and epoch == 0 and step == 5:
882
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
883
+ print(f"Step {step}: {used_gb:.2f} GB")
884
+
885
+ # noise
886
+ noise = torch.randn_like(latents, dtype=latents.dtype)
887
+
888
+ # Time t (sampled as before, with the endpoints squeezed in slightly)
889
+ u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
890
+ t = u * (1 - 2 * 1e-5) + 1e-5 # t is now strictly within (0.00001 ... 0.99999)
891
+ # Interpolate between x0 and noise
892
+ noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
893
+ # Scale t to the scheduler's timestep range for the UNet (kept as float32)
894
+ timesteps = t.to(torch.float32).mul(999.0)
895
+ timesteps = timesteps.clamp(0, scheduler.config.num_train_timesteps - 1)
896
+
897
+ # --- UNet call with attention mask ---
898
+ model_pred = unet(
899
+ noisy_latents,
900
+ timesteps,
901
+ encoder_hidden_states=embeddings,
902
+ encoder_attention_mask=attention_mask,
903
+ ).sample
904
+
905
+ target = noise - latents
906
+
907
+ mse_loss = F.mse_loss(model_pred.float(), target.float())
908
+ batch_losses.append(mse_loss.detach().item())
909
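The step above is rectified-flow matching: interpolate linearly between data and noise, then regress the constant velocity `noise - latents`. A hedged NumPy sketch of just that data path (the function name is illustrative):

```python
import numpy as np

def flow_matching_pair(x0, noise, t):
    # Rectified-flow interpolation: x_t moves linearly from the data x0 (t=0)
    # to pure noise (t=1); the regression target is the constant velocity
    # dx_t/dt = noise - x0, matching `target = noise - latents` in the loop.
    t = t.reshape(-1, 1, 1, 1)  # broadcast per-sample t over (C, H, W)
    x_t = (1.0 - t) * x0 + t * noise
    velocity = noise - x0
    return x_t, velocity
```

The MSE between the UNet's prediction and this velocity is the whole training loss; at sampling time the scheduler integrates the predicted velocity field from noise back to data.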
+
910
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
911
+ accelerator.wait_for_everyone()
912
+
913
+ losses_dict = {}
914
+ losses_dict["mse"] = mse_loss
915
+
916
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
917
+ accelerator.wait_for_everyone()
918
+
919
+ accelerator.backward(mse_loss)
920
+
921
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
922
+ accelerator.wait_for_everyone()
923
+
924
+ grad = 0.0
925
+ if not fbp:
926
+ if accelerator.sync_gradients:
927
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
928
+ grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
929
+ optimizer.step()
930
+ lr_scheduler.step()
931
+ optimizer.zero_grad(set_to_none=True)
932
+
933
+ if accelerator.sync_gradients:
934
+ global_step += 1
935
+ progress_bar.update(1)
936
+ if accelerator.is_main_process:
937
+ if fbp:
938
+ current_lr = base_learning_rate
939
+ else:
940
+ current_lr = lr_scheduler.get_last_lr()[0]
941
+ batch_grads.append(grad)
942
+
943
+ log_data = {}
944
+ log_data["loss_mse"] = mse_loss.detach().item()
945
+ log_data["lr"] = current_lr
946
+ log_data["grad"] = grad
947
+ if accelerator.sync_gradients:
948
+ if use_wandb:
949
+ wandb.log(log_data, step=global_step)
950
+ if use_comet_ml:
951
+ comet_experiment.log_metrics(log_data, step=global_step)
952
+
953
+ current_time = time.time()
954
+ is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
955
+ if is_time_to_sample or global_step == 50:
956
+ # Pass the (emb, mask) tuple for the negative conditioning
957
+ if save_model:
958
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
959
+ elif epoch % 10 == 0:
960
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
961
+ last_n = sample_interval
962
+
963
+ if save_model:
964
+ has_losses = len(batch_losses) > 0
965
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
966
+ last_loss = batch_losses[-1] if has_losses else 0.0
967
+ max_loss = max(avg_sample_loss, last_loss)
968
+ should_save = max_loss < min_loss * save_barrier
969
+ print(
970
+ f"Saving: {should_save} | Max: {max_loss:.4f} | "
971
+ f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
972
+ )
973
+ # Save the checkpoint and update the loss threshold
974
+ if should_save:
975
+ min_loss = max_loss
976
+ save_checkpoint(unet)
977
+ last_sample_time = current_time
978
+ unet.train()
979
+
980
+ if accelerator.is_main_process:
981
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
982
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
983
+
984
+ print(f"\nEpoch {epoch} finished. Mean loss: {avg_epoch_loss:.6f}")
985
+ log_data_ep = {
986
+ "epoch_loss": avg_epoch_loss,
987
+ "epoch_grad": avg_epoch_grad,
988
+ "epoch": epoch + 1,
989
+ }
990
+ if use_wandb:
991
+ wandb.log(log_data_ep)
992
+ if use_comet_ml:
993
+ comet_experiment.log_metrics(log_data_ep)
994
+
995
+ if accelerator.is_main_process:
996
+ print("Training finished! Saving the final model...")
997
+ #if save_model:
998
+ save_checkpoint(unet, "fp16")
999
+ if use_comet_ml:
1000
+ comet_experiment.end()
1001
+ accelerator.free_memory()
1002
+ if torch.distributed.is_initialized():
1003
+ torch.distributed.destroy_process_group()
1004
+
1005
+ print("Done!")
train.py CHANGED
@@ -28,6 +28,8 @@ import torch.nn.functional as F
28
  from collections import deque
29
  from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
30
  import argparse
 
 
31
 
32
  # --------------------------- Parameters ---------------------------
33
  ds_path = "datasets/ds1234_noanime_704_vae8x16x"
@@ -40,8 +42,8 @@ local_bs = max(1, int((gpu_mem_gb / 32) * 8))
40
  num_gpus = torch.cuda.device_count()
41
  batch_size = local_bs * num_gpus
42
  print(f"GPUs: {num_gpus}, Local BS: {local_bs}, Global BS: {local_bs * num_gpus}")
43
- base_learning_rate = 1e-5
44
- min_learning_rate = 1e-7
45
  num_epochs = num_gpus
46
  sample_interval_share = 20
47
  cfg_dropout = 0.10
@@ -60,7 +62,7 @@ shuffle = True
60
  muon_lr_scale = 300
61
  optimizer_type = "muon_adam8bit"
62
  if optimizer_type == "muon_adam8bit":
63
- batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 3))
64
  comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
65
  comet_ml_workspace = "recoilme"
66
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -434,156 +436,6 @@ if os.path.isdir(latest_checkpoint):
434
  else:
435
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")
436
 
437
- # --------------------------- Muon Implementation (FIXED for UNet Conv2d) ---------------------------
438
- import torch.distributed as dist
439
-
440
- def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
441
- """
442
- Оптимизированная версия: поддерживает BF16 и батчинг.
443
- """
444
- a, b, c = (3.4445, -4.7750, 2.0315)
445
-
446
- # Используем bfloat16 для скорости (если карта поддерживает, иначе float16/32)
447
- # Важно: приведение типа делаем здесь
448
- if G.dtype != torch.bfloat16 and G.dtype != torch.float32:
449
- # Если пришло float16, лучше считать в float32 или bfloat16 для точности
450
- X = G.float()
451
- else:
452
- X = G
453
-
454
- # Логика транспонирования (M > N)
455
- transposed = G.size(-2) > G.size(-1)
456
- if transposed:
457
- X = X.mT # .mT - это View, не копирует память
458
-
459
- # Нормализация по последним двум осям (корректно для любых размерностей)
460
- X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
461
-
462
- for _ in range(steps):
463
- # A = X @ X.T
464
- A = X @ X.mT
465
- # B = b*A + c*A@A
466
- B = b * A + c * A @ A
467
- # X = a*X + B@X
468
- X = a * X + B @ X
469
-
470
- return X.mT if transposed else X
471
-
472
- class MuonAdamW8bit(torch.optim.Optimizer):
473
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01,
474
- muon_lr=None, muon_momentum=0.95, ns_steps=5):
475
- params = list(params)
476
-
477
- matrix_params = []
478
- scalar_params = []
479
-
480
- for p in params:
481
- if p.ndim >= 2:
482
- matrix_params.append(p)
483
- else:
484
- scalar_params.append(p)
485
-
486
- # --- ВАЖНОЕ ИСПРАВЛЕНИЕ LR ---
487
- # Если базовый lr=1e-5, Muon должен быть ~0.01 -> 0.02
488
- # Ставим коэф x1000 или x2000 от базового, или задаем явно.
489
- # Автор ставит 0.02. Давайте поставим 0.01 как безопасный минимум для старта.
490
- actual_muon_lr = muon_lr if muon_lr is not None else lr * muon_lr_scale
491
- self.muon_lr_scale = actual_muon_lr / lr if lr > 0 else 1.0
492
- self.ns_steps = ns_steps
493
-
494
- # Инициализируем Muon (передаем ему только матрицы)
495
- # Обратите внимание: lr передаем актуальный
496
- self.muon_opt = MuonInternal(matrix_params, lr=actual_muon_lr, momentum=muon_momentum, ns_steps=ns_steps)
497
-
498
- # Adam для скаляров (biases, norms)
499
- self.adam_opt = bnb.optim.AdamW8bit(scalar_params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
500
-
501
- self.defaults = dict(lr=lr)
502
- self.state = {}
503
- self.param_groups = self.adam_opt.param_groups
504
-
505
- def zero_grad(self, set_to_none=False):
506
- self.muon_opt.zero_grad(set_to_none)
507
- self.adam_opt.zero_grad(set_to_none)
508
-
509
- @torch.no_grad()
510
- def step(self, closure=None):
511
- # Синхронизация LR (если шедулер меняет lr в adam группах)
512
- current_base_lr = self.adam_opt.param_groups[0]['lr']
513
- for group in self.muon_opt.param_groups:
514
- group['lr'] = current_base_lr * self.muon_lr_scale
515
-
516
- self.muon_opt.step(closure)
517
- self.adam_opt.step(closure)
518
-
519
- # state_dict/load_state_dict оставляем как было у вас
520
- def state_dict(self):
521
- return {
522
- 'muon': self.muon_opt.state_dict(),
523
- 'adam': self.adam_opt.state_dict()
524
- }
525
-
526
- def load_state_dict(self, state_dict):
527
- self.muon_opt.load_state_dict(state_dict['muon'])
528
- self.adam_opt.load_state_dict(state_dict['adam'])
529
-
530
-
531
- # Внутренний класс Muon (логика обновления матриц)
532
- class MuonInternal(torch.optim.Optimizer):
533
- def __init__(self, params, lr=0.01, momentum=0.95, ns_steps=5):
534
- super().__init__(params, dict(lr=lr, momentum=momentum, ns_steps=ns_steps))
535
- self.distributed = torch.distributed.is_initialized()
536
-
537
- @torch.no_grad()
538
- def step(self, closure=None):
539
- loss = None
540
- if closure is not None:
541
- with torch.enable_grad():
542
- loss = closure()
543
-
544
- for group in self.param_groups:
545
- lr = group["lr"]
546
- momentum = group["momentum"]
547
- ns_steps = group["ns_steps"]
548
-
549
- for p in group["params"]:
550
- if p.grad is None:
551
- continue
552
-
553
- g = p.grad
554
- state = self.state[p]
555
-
556
- if "momentum_buffer" not in state:
557
- state["momentum_buffer"] = torch.zeros_like(g, dtype=torch.float32)
558
-
559
- buf = state["momentum_buffer"]
560
-
561
- # Обновление момента
562
- buf.mul_(momentum).add_(g.float())
563
-
564
- # Nesterov
565
- update = g.float().add(buf, alpha=momentum)
566
-
567
- # Обработка Conv2d (flatten)
568
- original_shape = g.shape
569
- if g.ndim == 4:
570
- update = update.view(update.size(0), -1)
571
-
572
- # Ортогонализация (теперь принимает update и steps)
573
- update = zeropower_via_newtonschulz5(update, steps=ns_steps)
574
-
575
- # Масштабирование (стандарт для Muon)
576
- update *= max(1, update.size(-2) / update.size(-1)) ** 0.5
577
-
578
- # Возврат формы для Conv
579
- if g.ndim == 4:
580
- update = update.view(original_shape)
581
-
582
- p.add_(update, alpha=-lr)
583
- return loss
584
-
585
- # --------------------------- End Muon Implementation ---------------------------
586
-
587
  def create_optimizer(name, params):
588
  if name == "adam8bit":
589
  return bnb.optim.AdamW8bit(
@@ -600,7 +452,9 @@ def create_optimizer(name, params):
600
  betas=(0.9, betta2),
601
  eps=eps,
602
  weight_decay=0.01,
603
- muon_momentum=0.95
 
 
604
  )
605
  else:
606
  raise ValueError(f"Unknown optimizer: {name}")
 
28
  from collections import deque
29
  from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
30
  import argparse
31
+ # pip install git+https://github.com/recoilme/muon_adamw8bit.git
32
+ from muon_adamw8bit import MuonAdamW8bit
33
 
34
  # --------------------------- Parameters ---------------------------
35
  ds_path = "datasets/ds1234_noanime_704_vae8x16x"
 
42
  num_gpus = torch.cuda.device_count()
43
  batch_size = local_bs * num_gpus
44
  print(f"GPUs: {num_gpus}, Local BS: {local_bs}, Global BS: {local_bs * num_gpus}")
45
+ base_learning_rate = 4e-5
46
+ min_learning_rate = 4e-7
47
  num_epochs = num_gpus
48
  sample_interval_share = 20
49
  cfg_dropout = 0.10
 
62
  muon_lr_scale = 300
63
  optimizer_type = "muon_adam8bit"
64
  if optimizer_type == "muon_adam8bit":
65
+ batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 4))
66
  comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
67
  comet_ml_workspace = "recoilme"
68
  torch.backends.cuda.matmul.allow_tf32 = True
 
436
  else:
437
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")
438
439
  def create_optimizer(name, params):
440
  if name == "adam8bit":
441
  return bnb.optim.AdamW8bit(
 
452
  betas=(0.9, betta2),
453
  eps=eps,
454
  weight_decay=0.01,
455
+ muon_momentum=0.95,
456
+ ns_dtype=torch.bfloat16,
457
+ muon_lr_mult=1000.0,
458
  )
459
  else:
460
  raise ValueError(f"Unknown optimizer: {name}")
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:24a108643923cb0dc32d1be767a8004df4a4cc5dfd880464b864bd2b4b841fc9
3
  size 6294042336
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:feb68ae24f6d1ada2f50a1bbd849b45f5ffd116c6bc0438b28ef9f1bd4cbd3cd
3
  size 6294042336