recoilme committed on
Commit 63c5f0a · 1 Parent(s): e282e82
media/result_grid.jpg CHANGED

Git LFS Details

  • SHA256: 1d163e86821273d081c2e69818c628f56ffb932d713dbb1f4a3c3c742d464840
  • Pointer size: 132 Bytes
  • Size of remote file: 4.39 MB

Git LFS Details

  • SHA256: af9cd2ac730d4c22045af10f38266896637c1e28fd54aff58f0274718081331f
  • Pointer size: 132 Bytes
  • Size of remote file: 4.39 MB
samples/unet_384x768_0.jpg CHANGED

Git LFS Details

  • SHA256: a040f86547070a4c57b2df71dc3292a424f9f35bcb78eefa075465e50ebb1fdf
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB

Git LFS Details

  • SHA256: c31d94cb6a9311f40a7ceb41d42d65385851b155d5416d13143cf3ddfa3452de
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
samples/unet_416x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 43225e12cdadcb7a8aa356df3caf9c20bf21be65a9e43c23e5d0f2adc602531e
  • Pointer size: 130 Bytes
  • Size of remote file: 59.4 kB

Git LFS Details

  • SHA256: 8f13b696eb73a16e349d1af8c7a49fdce35abafb19c67b788df6e1bb82753460
  • Pointer size: 130 Bytes
  • Size of remote file: 59.3 kB
samples/unet_448x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 6e1f8b8cf67bae35819f1924f2b553ea1e7605acb74a3b94e2af55f94a0bc3b1
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB

Git LFS Details

  • SHA256: 1c5bbfae5758b169851db91b63b62dcc323902eefb249e642fdcf0bdc1cc95e0
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
samples/unet_480x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 96c51b5eb7457576638ca743c57374b75d6cbb19921613e2eaa05a411c53b913
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB

Git LFS Details

  • SHA256: 94b985a9fdefbe301e9146086fc333ddd91c202fabc520389396108cd083f227
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB
samples/unet_512x768_0.jpg CHANGED

Git LFS Details

  • SHA256: f9d60f168c3716835ec953ca6e683b26f86294bd4854c7d84c470482e77e5347
  • Pointer size: 131 Bytes
  • Size of remote file: 202 kB

Git LFS Details

  • SHA256: 6dc90f559bb696c5f4d3c35a6a032b582021bd1beb20f2af5470e4e075ab0d4e
  • Pointer size: 131 Bytes
  • Size of remote file: 266 kB
samples/unet_544x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 8399621f9c37e25f5d077bc104ffa07481abff4e40542c5da594e29f24266e42
  • Pointer size: 130 Bytes
  • Size of remote file: 84.2 kB

Git LFS Details

  • SHA256: 1868b53ea6df1ff82a3be9166659907dc52a329a495170ce7e34db741d2fe5db
  • Pointer size: 130 Bytes
  • Size of remote file: 83.1 kB
samples/unet_576x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 11e3da1c46f6bf7ac04dcca16dbf38bab919ab4572d2162a6d7c1e1487746dc6
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB

Git LFS Details

  • SHA256: 098a0981874cccc3dd31e8a03d0a0a2efac9b7c62c13c74186a2c6272b35e807
  • Pointer size: 131 Bytes
  • Size of remote file: 100 kB
samples/unet_608x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 477cde18a41782c5c8f3b80b4538520dbf4cd4870391dc0e9864596606694827
  • Pointer size: 130 Bytes
  • Size of remote file: 98.6 kB

Git LFS Details

  • SHA256: 3de722ebcc80c553becc1d806a9ce3688bd0d30d0155f2d284ff784e0c057c0c
  • Pointer size: 130 Bytes
  • Size of remote file: 95.2 kB
samples/unet_640x768_0.jpg CHANGED

Git LFS Details

  • SHA256: d57c88ae482d8fe0b1c2ab060e3e664c4f9ae783ee694b30dc68c44c1fce5017
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB

Git LFS Details

  • SHA256: e21123ea9c8647c22eb0e82cb03ab4166259dfdf748206b90755d5d41a118e18
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
samples/unet_672x768_0.jpg CHANGED

Git LFS Details

  • SHA256: c59da6bc264d3f34db700a4313353bc9c15d4e0f5cf6f0e035c55a210c95125a
  • Pointer size: 130 Bytes
  • Size of remote file: 84.8 kB

Git LFS Details

  • SHA256: 2e4da32f771eb2cf9192b91bf4632de5b0761ff667dc9c8e5937cd1e762f89c6
  • Pointer size: 130 Bytes
  • Size of remote file: 74.5 kB
samples/unet_704x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 9e169b71ebb0b785ed0612ae7139e58fcfa6e205ff40581f335771c3f3927ad6
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB

Git LFS Details

  • SHA256: 39551b77b520c4f3e5b0ca121b28518414bcd454973916ec122127828da34a1f
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
samples/unet_736x768_0.jpg CHANGED

Git LFS Details

  • SHA256: e6e54404c509e26c7f36c56ed3c93d64929048f9756052ac91d846d5438b1b25
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB

Git LFS Details

  • SHA256: aac31ccb66c3c013352a935e8ca25c02bd47ba04be2b8765822f122eb814b797
  • Pointer size: 130 Bytes
  • Size of remote file: 94.1 kB
samples/unet_768x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 0a8ad0b43141e74bff9dfbad0759ef1e1c0a9e06a40e24d4b2417f34f6c8ead5
  • Pointer size: 130 Bytes
  • Size of remote file: 79 kB

Git LFS Details

  • SHA256: 2c0d411167fecc6da9ce330ce84a400c6807f2b355fe497ebd767df4b5405d03
  • Pointer size: 131 Bytes
  • Size of remote file: 134 kB
samples/unet_768x416_0.jpg CHANGED

Git LFS Details

  • SHA256: 272c7c92656f8285f68cd5c3c8036dcbd860bc6481e47eebe0c0a07d77d955e2
  • Pointer size: 130 Bytes
  • Size of remote file: 39.6 kB

Git LFS Details

  • SHA256: b551cf13e4b90e65bb80de5c488837c27e6908d37d40a478582a8812f49ef410
  • Pointer size: 130 Bytes
  • Size of remote file: 59.9 kB
samples/unet_768x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 00e5551870277a8a7b5acfcad2567555bc47c62cb6fcc4b302368026266ada4a
  • Pointer size: 130 Bytes
  • Size of remote file: 66.9 kB

Git LFS Details

  • SHA256: ff9f92093f9fe3f6f42278efdbd19bd0355c8865325e8ffd6944edb5a8ac5e2b
  • Pointer size: 130 Bytes
  • Size of remote file: 91 kB
samples/unet_768x480_0.jpg CHANGED

Git LFS Details

  • SHA256: 85362442b8c96a255dcf7a2024f39d1b11c26edfe89fc3148a0e79cdd77f2335
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB

Git LFS Details

  • SHA256: 034ec77216964be75713bfef7eb3362db72fb13cecc1700d5f335722f3a4f22f
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
samples/unet_768x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 5037aada1a846cebb52dd11f88670f38de6c314dc12b7270194f3b2cdd6ec20e
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB

Git LFS Details

  • SHA256: d502866d550b8fa7a1ad9ccc7c520b0e7a0c734f65466575e39f0a7eac10adc4
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
samples/unet_768x544_0.jpg CHANGED

Git LFS Details

  • SHA256: 49dae1c4c588ea8c9be7d8146768598c6aaeced265c6ecbb6365e37adfe00e3d
  • Pointer size: 130 Bytes
  • Size of remote file: 97.5 kB

Git LFS Details

  • SHA256: 439e2eedb3256702e7486d1241f4c221694bbe6d5b715a4e281128f7883532d3
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
samples/unet_768x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 1b427e06d854f93984341a1043e2d01289cbdbea53734e6f20fd743f98a21610
  • Pointer size: 130 Bytes
  • Size of remote file: 66.4 kB

Git LFS Details

  • SHA256: 3e202f36f2e3316d908a0fa2bade4a5f396401dd37a027ac0bfea944591d3845
  • Pointer size: 130 Bytes
  • Size of remote file: 57.7 kB
samples/unet_768x608_0.jpg CHANGED

Git LFS Details

  • SHA256: 35434ac937839cd80735218c173d01afa9c9ade7a7af7ee33b841688f0e0d08f
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB

Git LFS Details

  • SHA256: 59744a32db00f0e6c5f8d40539ce09c665f52c5afbbcc4b5348a776c3553b9a3
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB
samples/unet_768x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 81dba776b65e7c474ea3685cffa33223ac8b1db222a932cb934932674841c1af
  • Pointer size: 131 Bytes
  • Size of remote file: 302 kB

Git LFS Details

  • SHA256: 3299f80dd0536458074e4efed7af91b28247e1c4577ec0427bc188915f3889cc
  • Pointer size: 131 Bytes
  • Size of remote file: 223 kB
samples/unet_768x672_0.jpg CHANGED

Git LFS Details

  • SHA256: 9813461aec2e9d5a6c1954bcd06f0bbc83c90ab36ad5405f1cc770031542dd51
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB

Git LFS Details

  • SHA256: f55b7eb9f15682634f0db94983f86719de721f0aa8e9d51843a585d96120ef6d
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
samples/unet_768x704_0.jpg CHANGED

Git LFS Details

  • SHA256: 8c65c85ee5c11ea87ddb993df9b01f647a906dfd28f15410435585e9ad46d38c
  • Pointer size: 130 Bytes
  • Size of remote file: 98.9 kB

Git LFS Details

  • SHA256: 1fe1c1bf08f82bb30b142e313b2b7b300c7b19f6b49e2b291df60b5b432453fc
  • Pointer size: 130 Bytes
  • Size of remote file: 87.5 kB
samples/unet_768x736_0.jpg CHANGED

Git LFS Details

  • SHA256: b26c17b10dcd4ea9f5cb5db1882d8345c73c9f48461b3dd33fc28bc70a379851
  • Pointer size: 131 Bytes
  • Size of remote file: 188 kB

Git LFS Details

  • SHA256: 6b0c0d0addba9ad99604f99e4f6396a2179c261cae5cdc324ef706ced751a446
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
samples/unet_768x768_0.jpg CHANGED

Git LFS Details

  • SHA256: a525d2879293102655f7f02a9beb383193a513d85a4d6642408f81276ce3885b
  • Pointer size: 131 Bytes
  • Size of remote file: 257 kB

Git LFS Details

  • SHA256: 69f78847fed9e8a3e2124cc73ddb2fe00c5be9ec32db47b3af2240b6060c47f6
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
test.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ab81c847e03c202d860e45d2b1150d0708302a73198a5f5c6cab734ebd81960
3
- size 3143592
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39f108e7891549639260ddfe82c52a689652efa7287e7d1bf5be6ac9453d34b1
3
+ size 8665517
train-Copy2.py ADDED
@@ -0,0 +1,771 @@
1
+ #from comet_ml import Experiment
2
+ import os
3
+ import math
4
+ import torch
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from torch.utils.data import DataLoader, Sampler
8
+ from torch.utils.data.distributed import DistributedSampler
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ from collections import defaultdict
11
+ from diffusers import UNet2DConditionModel, AutoencoderKL
12
+ from accelerate import Accelerator
13
+ from datasets import load_from_disk
14
+ from tqdm import tqdm
15
+ from PIL import Image, ImageOps
16
+ import wandb
17
+ import random
18
+ import gc
19
+ from accelerate.state import DistributedType
20
+ from torch.distributed import broadcast_object_list
21
+ from torch.utils.checkpoint import checkpoint
22
+ from diffusers.models.attention_processor import AttnProcessor2_0
23
+ from datetime import datetime
24
+ import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
+ from collections import deque
27
+ from transformers import AutoTokenizer, AutoModel
28
+
29
+ # --------------------------- Parameters ---------------------------
30
+ ds_path = "/workspace/sdxs/datasets/768"
31
+ project = "unet"
32
+ batch_size = 36
33
+ base_learning_rate = 2.7e-5 #4e-5
34
+ min_learning_rate = 1e-5 #2.7e-5
35
+ num_epochs = 80
36
+ sample_interval_share = 5
37
+ max_length = 192
38
+ use_wandb = True
39
+ use_comet_ml = False
40
+ save_model = True
41
+ use_decay = True
42
+ fbp = False
43
+ optimizer_type = "adam8bit"
44
+ torch_compile = False
45
+ unet_gradient = True
46
+ fixed_seed = False
47
+ shuffle = True
48
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
49
+ comet_ml_workspace = "recoilme"
50
+ torch.backends.cuda.matmul.allow_tf32 = True
51
+ torch.backends.cudnn.allow_tf32 = True
52
+ #torch.backends.cuda.enable_mem_efficient_sdp(False)
53
+ dtype = torch.float32
54
+ save_barrier = 1.01
55
+ warmup_percent = 0.01
56
+ percentile_clipping = 96 #97
57
+ betta2 = 0.999
58
+ eps = 1e-7
59
+ clip_grad_norm = 1.0
60
+ limit = 0
61
+ checkpoints_folder = ""
62
+ mixed_precision = "no"
63
+ gradient_accumulation_steps = 1
64
+
65
+ accelerator = Accelerator(
66
+ mixed_precision=mixed_precision,
67
+ gradient_accumulation_steps=gradient_accumulation_steps
68
+ )
69
+ device = accelerator.device
70
+
71
+ # Diffusion parameters
72
+ n_diffusion_steps = 40
73
+ samples_to_generate = 12
74
+ guidance_scale = 4
75
+
76
+ # Folders for saving results
77
+ generated_folder = "samples"
78
+ os.makedirs(generated_folder, exist_ok=True)
79
+
80
+ # Seed setup
81
+ current_date = datetime.now()
82
+ seed = int(current_date.strftime("%Y%m%d"))
83
+ if fixed_seed:
84
+ torch.manual_seed(seed)
85
+ np.random.seed(seed)
86
+ random.seed(seed)
87
+ if torch.cuda.is_available():
88
+ torch.cuda.manual_seed_all(seed)
89
+
90
+ # --------------------------- LoRA parameters ---------------------------
91
+ lora_name = ""
92
+ lora_rank = 32
93
+ lora_alpha = 64
94
+
95
+ print("init")
96
+
97
+ loss_ratios = {
98
+ "mse": 1.,
99
+ }
100
+ median_coeff_steps = 256
101
+
102
+ # Loss normalization by medians: compute the COEFFICIENTS
103
+ class MedianLossNormalizer:
104
+ def __init__(self, desired_ratios: dict, window_steps: int):
105
+ # normalize the ratios in case their sum != 1
106
+ s = sum(desired_ratios.values())
107
+ self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
108
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
109
+ self.window = window_steps
110
+
111
+ def update_and_total(self, losses: dict):
112
+ """
113
+ losses: dict ключ->тензор (значения лоссов)
114
+ Поведение:
115
+ - буферим ABS(l) только для активных (ratio>0) лоссов
116
+ - coeff = ratio / median(abs(loss))
117
+ - total = sum(coeff * loss) по активным лоссам
118
+ CHANGED: буферим abs() — чтобы медиана была положительной и не ломала деление.
119
+ """
120
+ # buffer only the active losses
121
+ for k, v in losses.items():
122
+ if k in self.buffers and self.ratios.get(k, 0) > 0:
123
+ self.buffers[k].append(float(v.detach().abs().cpu()))
124
+
125
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
126
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
127
+
128
+ # sum only over the active losses (ratio > 0)
129
+ total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
130
+ return total, coeffs, meds
131
+
132
+ # create the normalizer after loss_ratios is defined
133
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
134
+
135
+ # --------------------------- WandB initialization ---------------------------
136
+ if accelerator.is_main_process:
137
+ if use_wandb:
138
+ wandb.init(project=project+lora_name, config={
139
+ "batch_size": batch_size,
140
+ "base_learning_rate": base_learning_rate,
141
+ "num_epochs": num_epochs,
142
+ "optimizer_type": optimizer_type,
143
+ })
144
+ if use_comet_ml:
145
+ from comet_ml import Experiment
146
+ comet_experiment = Experiment(
147
+ api_key=comet_ml_api_key,
148
+ project_name=project,
149
+ workspace=comet_ml_workspace
150
+ )
151
+ hyper_params = {
152
+ "batch_size": batch_size,
153
+ "base_learning_rate": base_learning_rate,
154
+ "num_epochs": num_epochs,
155
+ }
156
+ comet_experiment.log_parameters(hyper_params)
157
+
158
+ # Enable Flash Attention 2 / SDPA
159
+ torch.backends.cuda.enable_flash_sdp(True)
160
+
161
+ # --------------------------- Model loading ---------------------------
162
+ vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
163
+ tokenizer = AutoTokenizer.from_pretrained("tokenizer")
164
+ text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
165
+
166
+ # --- [UPDATED] Text encoding function (with mask and pooling) ---
167
+ def encode_texts(texts, max_length=max_length):
168
+ # If texts are empty (for the unconditional case), create placeholders
169
+ if texts is None:
170
+ # With None we would return zeros (logic for get_negative_embedding),
171
+ # but here we normally expect a list of strings.
172
+ pass
173
+
174
+ with torch.no_grad():
175
+ if isinstance(texts, str):
176
+ texts = [texts]
177
+
178
+ for i, prompt_item in enumerate(texts):
179
+ messages = [
180
+ {"role": "user", "content": prompt_item},
181
+ ]
182
+ prompt_item = tokenizer.apply_chat_template(
183
+ messages,
184
+ tokenize=False,
185
+ add_generation_prompt=True,
186
+ #enable_thinking=True,
187
+ )
188
+ #print(prompt_item+"\n")
189
+ texts[i] = prompt_item
190
+
191
+ toks = tokenizer(
192
+ texts,
193
+ return_tensors="pt",
194
+ padding="max_length",
195
+ truncation=True,
196
+ max_length=max_length
197
+ ).to(device)
198
+
199
+ outs = text_model(**toks, output_hidden_states=True, return_dict=True)
200
+
201
+ # Use last_hidden_state or hidden_states[-1] (for Qwen, last_hidden_state is said to be better); the code below takes hidden_states[-2] instead
202
+ hidden = outs.hidden_states[-2]
203
+
204
+ # 2. Attention mask
205
+ attention_mask = toks["attention_mask"]
206
+
207
+ # 3. Pooled embedding (last token)
208
+ sequence_lengths = attention_mask.sum(dim=1) - 1
209
+ batch_size = hidden.shape[0]
210
+ pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
211
+
212
+ #return hidden, attention_mask
213
+ # --- NEW LOGIC: CONCATENATION FOR CROSS-ATTENTION ---
214
+ # 1. Expand the pooled vector into a sequence [B, 1, emb]
215
+ pooled_expanded = pooled.unsqueeze(1)
216
+
217
+ # 2. Concatenate the token sequence and the pooled vector
218
+ # !!! CHANGE HERE !!!: the pooled vector goes FIRST
219
+ # Now: [B, 1 + L, emb]. The pooled vector becomes the token at the START.
220
+ new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
221
+
222
+ # 3. Update the attention mask for the new token
223
+ # Attention mask: [B, 1 + L]. Prepend a 1 at the START.
224
+ # torch.ones((batch_size, 1), device=device) creates a [B, 1] mask of ones.
225
+ new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
226
+
227
+ return new_encoder_hidden_states, new_attention_mask
228
+
229
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
230
+ if shift_factor is None: shift_factor = 0.0
231
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
232
+ if scaling_factor is None: scaling_factor = 1.0
233
+
234
+ from diffusers import FlowMatchEulerDiscreteScheduler
235
+ num_train_timesteps = 1000
236
+ scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps)
237
+
238
+ class DistributedResolutionBatchSampler(Sampler):
239
+ def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
240
+ self.dataset = dataset
241
+ self.batch_size = max(1, batch_size // num_replicas)
242
+ self.num_replicas = num_replicas
243
+ self.rank = rank
244
+ self.shuffle = shuffle
245
+ self.drop_last = drop_last
246
+ self.epoch = 0
247
+
248
+ try:
249
+ widths = np.array(dataset["width"])
250
+ heights = np.array(dataset["height"])
251
+ except KeyError:
252
+ widths = np.zeros(len(dataset))
253
+ heights = np.zeros(len(dataset))
254
+
255
+ self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
256
+ self.size_groups = {}
257
+ for w, h in self.size_keys:
258
+ mask = (widths == w) & (heights == h)
259
+ self.size_groups[(w, h)] = np.where(mask)[0]
260
+
261
+ self.group_num_batches = {}
262
+ total_batches = 0
263
+ for size, indices in self.size_groups.items():
264
+ num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
265
+ self.group_num_batches[size] = num_full_batches
266
+ total_batches += num_full_batches
267
+
268
+ self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
269
+
270
+ def __iter__(self):
271
+ if torch.cuda.is_available():
272
+ torch.cuda.empty_cache()
273
+ all_batches = []
274
+ rng = np.random.RandomState(self.epoch)
275
+
276
+ for size, indices in self.size_groups.items():
277
+ indices = indices.copy()
278
+ if self.shuffle:
279
+ rng.shuffle(indices)
280
+ num_full_batches = self.group_num_batches[size]
281
+ if num_full_batches == 0:
282
+ continue
283
+ valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
284
+ batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
285
+ start_idx = self.rank * self.batch_size
286
+ end_idx = start_idx + self.batch_size
287
+ gpu_batches = batches[:, start_idx:end_idx]
288
+ all_batches.extend(gpu_batches)
289
+
290
+ if self.shuffle:
291
+ rng.shuffle(all_batches)
292
+ accelerator.wait_for_everyone()
293
+ return iter(all_batches)
294
+
295
+ def __len__(self):
296
+ return self.num_batches
297
+
298
+ def set_epoch(self, epoch):
299
+ self.epoch = epoch
300
+
301
+ # --- [UPDATED] Fixed-samples function ---
302
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
303
+ size_groups = defaultdict(list)
304
+ try:
305
+ widths = dataset["width"]
306
+ heights = dataset["height"]
307
+ except KeyError:
308
+ widths = [0] * len(dataset)
309
+ heights = [0] * len(dataset)
310
+ for i, (w, h) in enumerate(zip(widths, heights)):
311
+ size = (w, h)
312
+ size_groups[size].append(i)
313
+
314
+ fixed_samples = {}
315
+ for size, indices in size_groups.items():
316
+ n_samples = min(samples_per_group, len(indices))
317
+ if len(size_groups)==1:
318
+ n_samples = samples_to_generate
319
+ if n_samples == 0:
320
+ continue
321
+ sample_indices = random.sample(indices, n_samples)
322
+ samples_data = [dataset[idx] for idx in sample_indices]
323
+
324
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
325
+ texts = [item["text"] for item in samples_data]
326
+
327
+ # Encode the texts on the fly to obtain masks and pooling
328
+ embeddings, masks = encode_texts(texts)
329
+
330
+ fixed_samples[size] = (latents, embeddings, masks, texts)
331
+
332
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
333
+ return fixed_samples
334
+
335
+ if limit > 0:
336
+ dataset = load_from_disk(ds_path).select(range(limit))
337
+ else:
338
+ dataset = load_from_disk(ds_path)
339
+
340
+ # --- [UPDATED] Collate Function ---
341
+ def collate_fn_simple(batch):
342
+ # 1. Latents (VAE)
343
+ latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)
344
+
345
+ # 2. Take the raw text from the dataset
346
+ raw_texts = [item["text"] for item in batch]
347
+ texts = [
348
+ "" if t.lower().startswith("zero")
349
+ else "" if random.random() < 0.05
350
+ else t[1:].lstrip() if t.startswith(".")
351
+ else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
352
+ for t in raw_texts
353
+ ]
354
+
355
+ # 3. Encode on the fly
356
+ # Returns: hidden (B, L, D), mask (B, L)
357
+ embeddings, attention_mask = encode_texts(texts)
358
+
359
+ # the attention_mask from the tokenizer is already in the right format, but cast it to long just in case
360
+ attention_mask = attention_mask.to(dtype=torch.int64)
361
+
362
+ return latents, embeddings, attention_mask
363
+
364
+ batch_sampler = DistributedResolutionBatchSampler(
365
+ dataset=dataset,
366
+ batch_size=batch_size,
367
+ num_replicas=accelerator.num_processes,
368
+ rank=accelerator.process_index,
369
+ shuffle=shuffle
370
+ )
371
+
372
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
373
+ print("Total samples", len(dataloader))
374
+ dataloader = accelerator.prepare(dataloader)
375
+
376
+ start_epoch = 0
377
+ global_step = 0
378
+ total_training_steps = (len(dataloader) * num_epochs)
379
+ world_size = accelerator.state.num_processes
380
+
381
+ # Load the UNet
382
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
383
+ if os.path.isdir(latest_checkpoint):
384
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
385
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
386
+ if unet_gradient:
387
+ unet.enable_gradient_checkpointing()
388
+ unet.set_use_memory_efficient_attention_xformers(False)
389
+ try:
390
+ unet.set_attn_processor(AttnProcessor2_0())
391
+ except Exception as e:
392
+ print(f"Ошибка при включении SDPA: {e}")
393
+ unet.set_use_memory_efficient_attention_xformers(True)
394
+ else:
395
+ raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")
396
+
397
+ if lora_name:
398
+ # ... (LoRA code unchanged, omitted for brevity; uncomment the original block if you need it) ...
399
+ pass
400
+
401
+ # Optimizer
402
+ if lora_name:
403
+ trainable_params = [p for p in unet.parameters() if p.requires_grad]
404
+ else:
405
+ if fbp:
406
+ trainable_params = list(unet.parameters())
407
+
408
+ def create_optimizer(name, params):
409
+ if name == "adam8bit":
410
+ return bnb.optim.AdamW8bit(
411
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
412
+ percentile_clipping=percentile_clipping
413
+ )
414
+ elif name == "adam":
415
+ return torch.optim.AdamW(
416
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
417
+ )
418
+ elif name == "muon":
419
+ from muon import MuonWithAuxAdam
420
+ trainable_params = [p for p in params if p.requires_grad]
421
+ hidden_weights = [p for p in trainable_params if p.ndim >= 2]
422
+ hidden_gains_biases = [p for p in trainable_params if p.ndim < 2]
423
+
424
+ param_groups = [
425
+ dict(params=hidden_weights, use_muon=True,
426
+ lr=1e-3, weight_decay=1e-4),
427
+ dict(params=hidden_gains_biases, use_muon=False,
428
+ lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
429
+ ]
430
+ optimizer = MuonWithAuxAdam(param_groups)
431
+ from snooc import SnooC
432
+ return SnooC(optimizer)
433
+ else:
434
+ raise ValueError(f"Unknown optimizer: {name}")
435
+
436
+ if fbp:
437
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
438
+ def optimizer_hook(param):
439
+ optimizer_dict[param].step()
440
+ optimizer_dict[param].zero_grad(set_to_none=True)
441
+ for param in trainable_params:
442
+ param.register_post_accumulate_grad_hook(optimizer_hook)
443
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
444
+ else:
445
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
446
+ def lr_schedule(step):
447
+ x = step / (total_training_steps * world_size)
448
+ warmup = warmup_percent
449
+ if not use_decay:
450
+ return base_learning_rate
451
+ if x < warmup:
452
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
453
+ decay_ratio = (x - warmup) / (1 - warmup)
454
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
455
+ (1 + math.cos(math.pi * decay_ratio))
456
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
457
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
458
+
459
+ if torch_compile:
460
+ print("compiling")
461
+ unet = torch.compile(unet)
462
+ print("compiling - ok")
463
+
464
+ # Fixed samples
465
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
466
+
467
+ # --- [UPDATED] Negative-embedding function (returns the embedding and its mask) ---
468
+ def get_negative_embedding(neg_prompt="", batch_size=1):
469
+ if not neg_prompt:
470
+ hidden_dim = 2048
471
+ seq_len = max_length
472
+ empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
473
+ empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
474
+ return empty_emb, empty_mask
475
+
476
+ uncond_emb, uncond_mask = encode_texts([neg_prompt])
477
+ uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
478
+ uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
479
+
480
+ return uncond_emb, uncond_mask
481
+
482
+ # Get the negative (empty) conditions for validation
483
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
484
+
485
+ # --- Sample generation function ---
486
+ @torch.compiler.disable()
487
+ @torch.no_grad()
488
+ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
489
+ uncond_emb, uncond_mask = uncond_data
490
+
491
+ original_model = None
492
+ try:
493
+ if not torch_compile:
494
+ original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
495
+ else:
496
+ original_model = unet.eval()
497
+
498
+ vae.to(device=device).eval()
499
+
500
+ all_generated_images = []
501
+ all_captions = []
502
+
503
+ # Unpack the 4 elements per group (the mask was added)
504
+ for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
505
+ width, height = size
506
+ sample_latents = sample_latents.to(dtype=dtype, device=device)
507
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
508
+ sample_mask = sample_mask.to(device=device)
509
+
510
+ latents = torch.randn(
511
+ sample_latents.shape,
512
+ device=device,
513
+ dtype=sample_latents.dtype,
514
+ generator=torch.Generator(device=device).manual_seed(seed)
515
+ )
516
+
517
+ scheduler.set_timesteps(n_diffusion_steps, device=device)
518
+
519
+ for t in scheduler.timesteps:
520
+ if guidance_scale != 1:
521
+ latent_model_input = torch.cat([latents, latents], dim=0)
522
+
523
+ # Prepare the batches for CFG (negative + positive)
524
+ # 1. Embeddings
525
+ curr_batch_size = sample_text_embeddings.shape[0]
526
+ seq_len = sample_text_embeddings.shape[1]
527
+ hidden_dim = sample_text_embeddings.shape[2]
528
+
529
+ neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
530
+ text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)
531
+
532
+ # 2. Masks
533
+ neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
534
+ attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
535
+
536
+ else:
537
+ latent_model_input = latents
538
+ text_embeddings_batch = sample_text_embeddings
539
+ attention_mask_batch = sample_mask
540
+
541
+ # Prediction with all conditions passed in
542
+ model_out = original_model(
543
+ latent_model_input,
544
+ t,
545
+ encoder_hidden_states=text_embeddings_batch,
546
+ encoder_attention_mask=attention_mask_batch,
547
+ )
548
+ flow = getattr(model_out, "sample", model_out)
549
+
550
+ if guidance_scale != 1:
551
+ flow_uncond, flow_cond = flow.chunk(2)
552
+ flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
553
+
554
+ latents = scheduler.step(flow, t, latents).prev_sample
555
+
556
+ current_latents = latents
557
+
558
+ latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
559
+ decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
560
+ decoded_fp32 = decoded.to(torch.float32)
561
+
562
+ for img_idx, img_tensor in enumerate(decoded_fp32):
563
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
564
+ img = img.transpose(1, 2, 0)
565
+
566
+ if np.isnan(img).any():
567
+ print("NaNs found, saving stopped! Step:", step)
568
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
569
+
570
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
571
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
572
+ max_w_overall = max(255, max_w_overall)
573
+ max_h_overall = max(255, max_h_overall)
574
+
575
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
576
+ all_generated_images.append(padded_img)
577
+
578
+ caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
579
+ all_captions.append(caption_text)
580
+
581
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
582
+ pil_img.save(sample_path, "JPEG", quality=96)
583
+
584
+ if use_wandb and accelerator.is_main_process:
585
+ wandb_images = [
586
+ wandb.Image(img, caption=f"{all_captions[i]}")
587
+ for i, img in enumerate(all_generated_images)
588
+ ]
589
+ wandb.log({"generated_images": wandb_images})
590
+ if use_comet_ml and accelerator.is_main_process:
591
+ for i, img in enumerate(all_generated_images):
592
+ comet_experiment.log_image(
593
+ image_data=img,
594
+ name=f"step_{step}_img_{i}",
595
+ step=step,
596
+ metadata={"caption": all_captions[i]}
597
+ )
598
+ finally:
599
+ vae.to("cpu")
600
+ torch.cuda.empty_cache()
601
+ gc.collect()
602
+
603
+ # --------------------------- Sample generation before training ---------------------------
604
+ if accelerator.is_main_process:
605
+ if save_model:
606
+ print("Генерация сэмплов до старта обучения...")
607
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
608
+ accelerator.wait_for_everyone()
609
+
610
+ def save_checkpoint(unet, variant=""):
611
+ if accelerator.is_main_process:
612
+ if lora_name:
613
+ save_lora_checkpoint(unet)
614
+ else:
615
+ model_to_save = None
616
+ if not torch_compile:
617
+ model_to_save = accelerator.unwrap_model(unet)
618
+ else:
619
+ model_to_save = unet
620
+
621
+ if variant != "":
622
+ model_to_save.to(dtype=torch.float16).save_pretrained(
623
+ os.path.join(checkpoints_folder, f"{project}"), variant=variant
624
+ )
625
+ else:
626
+ model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
627
+
628
+ unet = unet.to(dtype=dtype)
629
+
630
+ # --------------------------- Training loop ---------------------------
631
+ if accelerator.is_main_process:
632
+ print(f"Total steps per GPU: {total_training_steps}")
633
+
634
+ epoch_loss_points = []
635
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
636
+
637
+ steps_per_epoch = len(dataloader)
638
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
639
+ min_loss = 2.
640
+
641
+ for epoch in range(start_epoch, start_epoch + num_epochs):
642
+ batch_losses = []
643
+ batch_grads = []
644
+ batch_sampler.set_epoch(epoch)
645
+ accelerator.wait_for_everyone()
646
+ unet.train()
647
+
648
+ for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
649
+ with accelerator.accumulate(unet):
650
+ if not save_model and step == 5:
651
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
652
+ print(f"Шаг {step}: {used_gb:.2f} GB")
653
+
654
+ # noise
655
+ noise = torch.randn_like(latents, dtype=latents.dtype)
656
+ # draw t from [0, 1]
657
+ t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
658
+ # interpolate between x0 and the noise
659
+ noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
660
+ # make integer timesteps for the UNet
661
+ timesteps = (t * scheduler.config.num_train_timesteps).long()
662
+
663
+ # --- Call the UNet with the mask ---
664
+ model_pred = unet(
665
+ noisy_latents,
666
+ timesteps,
667
+ encoder_hidden_states=embeddings,
668
+ encoder_attention_mask=attention_mask
669
+ ).sample
670
+
671
+ target = noise - latents
672
+ mse_loss = F.mse_loss(model_pred.float(), target.float())
673
+ batch_losses.append(mse_loss.detach().item())
674
+
675
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
676
+ accelerator.wait_for_everyone()
677
+
678
+ losses_dict = {}
679
+ losses_dict["mse"] = mse_loss
680
+
681
+ # === Normalize all losses ===
682
+ abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
683
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
684
+
685
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
686
+ accelerator.wait_for_everyone()
687
+
688
+ accelerator.backward(total_loss)
689
+
690
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
691
+ accelerator.wait_for_everyone()
692
+
693
+ grad = 0.0
694
+ if not fbp:
695
+ if accelerator.sync_gradients:
696
+ #with torch.amp.autocast('cuda', enabled=False):
697
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
698
+ grad = float(grad_val)
699
+ optimizer.step()
700
+ lr_scheduler.step()
701
+ optimizer.zero_grad(set_to_none=True)
702
+
703
+ if accelerator.sync_gradients:
704
+ global_step += 1
705
+ progress_bar.update(1)
706
+ if accelerator.is_main_process:
707
+ if fbp:
708
+ current_lr = base_learning_rate
709
+ else:
710
+ current_lr = lr_scheduler.get_last_lr()[0]
711
+ batch_grads.append(grad)
712
+
713
+ log_data = {}
714
+ log_data["loss"] = mse_loss.detach().item()
715
+ log_data["lr"] = current_lr
716
+ log_data["grad"] = grad
717
+ log_data["loss_total"] = float(total_loss.item())
718
+ for k, c in coeffs.items():
719
+ log_data[f"coeff_{k}"] = float(c)
720
+ if accelerator.sync_gradients:
721
+ if use_wandb:
722
+ wandb.log(log_data, step=global_step)
723
+ if use_comet_ml:
724
+ comet_experiment.log_metrics(log_data, step=global_step)
725
+
726
+ if global_step % sample_interval == 0:
727
+ # Pass the (emb, mask) tuple for the negative prompt
728
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
729
+ last_n = sample_interval
730
+
731
+ if save_model:
732
+ has_losses = len(batch_losses) > 0
733
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
734
+ last_loss = batch_losses[-1] if has_losses else 0.0
735
+ max_loss = max(avg_sample_loss, last_loss)
736
+ should_save = max_loss < min_loss * save_barrier
737
+ print(
738
+ f"Saving: {should_save} | Max: {max_loss:.4f} | "
739
+ f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
740
+ )
741
+ # 6. Save and update
742
+ if should_save:
743
+ min_loss = max_loss
744
+ save_checkpoint(unet)
745
+
746
+ if accelerator.is_main_process:
747
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
748
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
749
+
750
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
751
+ log_data_ep = {
752
+ "epoch_loss": avg_epoch_loss,
753
+ "epoch_grad": avg_epoch_grad,
754
+ "epoch": epoch + 1,
755
+ }
756
+ if use_wandb:
757
+ wandb.log(log_data_ep)
758
+ if use_comet_ml:
759
+ comet_experiment.log_metrics(log_data_ep)
760
+
761
+ if accelerator.is_main_process:
762
+ print("Обучение завершено! Сохраняем финальную модель...")
763
+ if save_model:
764
+ save_checkpoint(unet,"fp16")
765
+ if use_comet_ml:
766
+ comet_experiment.end()
767
+ accelerator.free_memory()
768
+ if torch.distributed.is_initialized():
769
+ torch.distributed.destroy_process_group()
770
+
771
+ print("Готово!")
train.py CHANGED
@@ -8,7 +8,7 @@ from torch.utils.data import DataLoader, Sampler
8
  from torch.utils.data.distributed import DistributedSampler
9
  from torch.optim.lr_scheduler import LambdaLR
10
  from collections import defaultdict
11
- from diffusers import UNet2DConditionModel, AutoencoderKL
12
  from accelerate import Accelerator
13
  from datasets import load_from_disk
14
  from tqdm import tqdm
@@ -30,7 +30,7 @@ from transformers import AutoTokenizer, AutoModel
30
  ds_path = "/workspace/sdxs/datasets/768"
31
  project = "unet"
32
  batch_size = 36
33
- base_learning_rate = 2.7e-5 #4e-5
34
  min_learning_rate = 1e-5 #2.7e-5
35
  num_epochs = 80
36
  sample_interval_share = 5
@@ -49,7 +49,7 @@ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
49
  comet_ml_workspace = "recoilme"
50
  torch.backends.cuda.matmul.allow_tf32 = True
51
  torch.backends.cudnn.allow_tf32 = True
52
- #torch.backends.cuda.enable_mem_efficient_sdp(False)
53
  dtype = torch.float32
54
  save_barrier = 1.01
55
  warmup_percent = 0.01
@@ -71,7 +71,7 @@ device = accelerator.device
71
  # Diffusion parameters
72
  n_diffusion_steps = 40
73
  samples_to_generate = 12
74
- guidance_scale = 4
75
 
76
  # Folders for saving results
77
  generated_folder = "samples"
@@ -95,9 +95,10 @@ lora_alpha = 64
95
  print("init")
96
 
97
  loss_ratios = {
98
- "mse": 1.,
 
99
  }
100
- median_coeff_steps = 128
101
 
102
  # Loss normalization by medians: compute the COEFFICIENTS
103
  class MedianLossNormalizer:
@@ -120,7 +121,9 @@ class MedianLossNormalizer:
120
  # buffer only the active losses
121
  for k, v in losses.items():
122
  if k in self.buffers and self.ratios.get(k, 0) > 0:
123
- self.buffers[k].append(float(v.detach().abs().cpu()))
 
 
124
 
125
  meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
126
  coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
@@ -160,6 +163,7 @@ torch.backends.cuda.enable_flash_sdp(True)
160
 
161
  # --------------------------- Model loading ---------------------------
162
  vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
 
163
  tokenizer = AutoTokenizer.from_pretrained("tokenizer")
164
  text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
165
 
@@ -346,7 +350,7 @@ def collate_fn_simple(batch):
346
  raw_texts = [item["text"] for item in batch]
347
  texts = [
348
  "" if t.lower().startswith("zero")
349
- else "" if random.random() < 0.05
350
  else t[1:].lstrip() if t.startswith(".")
351
  else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
352
  for t in raw_texts
@@ -654,7 +658,10 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
654
  # noise
655
  noise = torch.randn_like(latents, dtype=latents.dtype)
656
  # draw t from [0, 1]
657
- t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
 
 
 
658
  # interpolate between x0 and the noise
659
  noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
660
  # make integer timesteps for the UNet
@@ -668,8 +675,10 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
668
  encoder_attention_mask=attention_mask
669
  ).sample
670
 
671
- target = noise - latents
 
672
  mse_loss = F.mse_loss(model_pred.float(), target.float())
 
673
  batch_losses.append(mse_loss.detach().item())
674
 
675
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
@@ -677,6 +686,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
677
 
678
  losses_dict = {}
679
  losses_dict["mse"] = mse_loss
 
680
 
681
  # === Normalize all losses ===
682
  abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
@@ -711,10 +721,11 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
711
  batch_grads.append(grad)
712
 
713
  log_data = {}
714
- log_data["loss"] = mse_loss.detach().item()
 
715
  log_data["lr"] = current_lr
716
  log_data["grad"] = grad
717
- log_data["loss_total"] = float(total_loss.item())
718
  for k, c in coeffs.items():
719
  log_data[f"coeff_{k}"] = float(c)
720
  if accelerator.sync_gradients:
@@ -725,7 +736,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
725
 
726
  if global_step % sample_interval == 0:
727
  # Pass the (emb, mask) tuple for the negative prompt
728
- generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
 
729
  last_n = sample_interval
730
 
731
  if save_model:
 
8
  from torch.utils.data.distributed import DistributedSampler
9
  from torch.optim.lr_scheduler import LambdaLR
10
  from collections import defaultdict
11
+ from diffusers import UNet2DConditionModel, AutoencoderKL, AutoencoderKLFlux2
12
  from accelerate import Accelerator
13
  from datasets import load_from_disk
14
  from tqdm import tqdm
 
30
  ds_path = "/workspace/sdxs/datasets/768"
31
  project = "unet"
32
  batch_size = 36
33
+ base_learning_rate = 2.7e-5
34
  min_learning_rate = 1e-5 #2.7e-5
35
  num_epochs = 80
36
  sample_interval_share = 5
 
49
  comet_ml_workspace = "recoilme"
50
  torch.backends.cuda.matmul.allow_tf32 = True
51
  torch.backends.cudnn.allow_tf32 = True
52
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
53
  dtype = torch.float32
54
  save_barrier = 1.01
55
  warmup_percent = 0.01
 
71
  # Diffusion parameters
72
  n_diffusion_steps = 40
73
  samples_to_generate = 12
74
+ guidance_scale = 2
75
 
76
  # Folders for saving results
77
  generated_folder = "samples"
 
95
  print("init")
96
 
97
  loss_ratios = {
98
+ "mse": 0.8,
99
+ "mae": 0.2,
100
  }
101
+ median_coeff_steps = 256
102
 
103
  # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ
104
  class MedianLossNormalizer:
 
121
  # buffer only the active losses
122
  for k, v in losses.items():
123
  if k in self.buffers and self.ratios.get(k, 0) > 0:
124
+ val = v.detach().abs().mean().cpu().item()  # .item() is better than float() for tensors
125
+ self.buffers[k].append(val)
126
+ #self.buffers[k].append(float(v.detach().abs().cpu()))
127
 
128
  meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
129
  coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
 
163
 
164
  # --------------------------- Загрузка моделей ---------------------------
165
  vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
166
+ #vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev",subfolder="vae",torch_dtype=dtype).to(device).eval()
167
  tokenizer = AutoTokenizer.from_pretrained("tokenizer")
168
  text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
169
 
 
350
  raw_texts = [item["text"] for item in batch]
351
  texts = [
352
  "" if t.lower().startswith("zero")
353
+ else "" if random.random() < 0.1
354
  else t[1:].lstrip() if t.startswith(".")
355
  else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
356
  for t in raw_texts
 
658
  # noise
659
  noise = torch.randn_like(latents, dtype=latents.dtype)
660
  # draw t from [0, 1]
661
+ #t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
662
+ u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
663
+ t = torch.sigmoid(torch.randn_like(u))
664
+
665
  # interpolate between x0 and the noise
666
  noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
667
  # make integer timesteps for the UNet
 
675
  encoder_attention_mask=attention_mask
676
  ).sample
677
 
678
+ target = noise - latents
679
+
680
  mse_loss = F.mse_loss(model_pred.float(), target.float())
681
+ mae_loss = F.l1_loss(model_pred.float(), target.float())
682
  batch_losses.append(mse_loss.detach().item())
683
 
684
  if (global_step % 100 == 0) or (global_step % sample_interval == 0):
 
686
 
687
  losses_dict = {}
688
  losses_dict["mse"] = mse_loss
689
+ losses_dict["mae"] = mae_loss
690
 
691
  # === Normalize all losses ===
692
  abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
 
721
  batch_grads.append(grad)
722
 
723
  log_data = {}
724
+ log_data["loss_mse"] = mse_loss.detach().item()
725
+ log_data["loss_mae"] = mae_loss.detach().item()
726
  log_data["lr"] = current_lr
727
  log_data["grad"] = grad
728
+ log_data["loss_norm"] = float(total_loss.item())
729
  for k, c in coeffs.items():
730
  log_data[f"coeff_{k}"] = float(c)
731
  if accelerator.sync_gradients:
 
736
 
737
  if global_step % sample_interval == 0:
738
  # Pass the (emb, mask) tuple for the negative prompt
739
+ if epoch % 10 == 0:
740
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
741
  last_n = sample_interval
742
 
743
  if save_model:
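Two behavioral changes in this train.py diff are easy to miss among the hunks: the timestep sampler moves from uniform t to the sigmoid of a standard normal draw (a logit-normal distribution concentrated around t = 0.5), and the objective is split into MSE and MAE terms that the MedianLossNormalizer re-weights toward the 0.8 / 0.2 ratios set in loss_ratios. The snippet below is a sketch of those two pieces in isolation; the helper names are illustrative and not defined in the commit.

import torch
import torch.nn.functional as F

def sample_t_logit_normal(batch_size, device="cpu", dtype=torch.float32):
    # replaces t = torch.rand(...): sigmoid of N(0, 1) stays strictly inside (0, 1)
    # but puts more probability mass on mid-range noise levels
    return torch.sigmoid(torch.randn(batch_size, device=device, dtype=dtype))

def per_term_losses(model_pred, target):
    # both terms are logged separately; the MedianLossNormalizer then combines
    # them into one total with coefficients aimed at the 0.8 (mse) / 0.2 (mae) split
    return {
        "mse": F.mse_loss(model_pred.float(), target.float()),
        "mae": F.l1_loss(model_pred.float(), target.float()),
    }

# toy check of the sampler's shape and range
t = sample_t_logit_normal(8)
assert t.shape == (8,) and bool(((t > 0) & (t < 1)).all())
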
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:18ba499f21e7872f47137bc29a4ae8f8d8d48d7b7023fc9bd68b9fb80408be81
3
  size 7444321360
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c9cc2753742ec7627ccddab26e135a2016ac587d40a6c699ee4c7058e5e5107
3
  size 7444321360