recoilme committed
Commit c0d2dc5 · 1 Parent(s): 9fee47d
samples/unet_192x320_0.jpg CHANGED

Git LFS Details (before)
  • SHA256: 00bf96924be80bc2a3656bb36657dc7492c3babcc1a56f41c7b2a747605624f8
  • Pointer size: 130 Bytes
  • Size of remote file: 29.5 kB

Git LFS Details (after)
  • SHA256: 083d5ba0480de202e0271a98263773ec37ac536108af0ca9c4ccb9f8fa184a3d
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB

samples/unet_256x320_0.jpg CHANGED

Git LFS Details (before)
  • SHA256: 75eadb9af4e3814c044a7512900438b119ff5c7bd34bb5cea3651109a52d7d50
  • Pointer size: 130 Bytes
  • Size of remote file: 37.4 kB

Git LFS Details (after)
  • SHA256: 85d6827f291dc5a57172534a3b6d6abf0bad7752402410bd79a0140031700c51
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB

samples/unet_320x192_0.jpg CHANGED

Git LFS Details (before)
  • SHA256: a53e366130a9efef73b77a5638745fb0c7972730e496327388362676f7aa85d9
  • Pointer size: 130 Bytes
  • Size of remote file: 11.1 kB

Git LFS Details (after)
  • SHA256: 8cdc04d8198e97cf795bf650768c3ad8ceaedeee7a066664e1721e188ea60374
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB

samples/unet_320x256_0.jpg CHANGED

Git LFS Details (before)
  • SHA256: 0148b2b4fdcd99a420bc6ad84e0e441879ffc885bd367834f517f9abfcc0f493
  • Pointer size: 130 Bytes
  • Size of remote file: 18.5 kB

Git LFS Details (after)
  • SHA256: e75834987b0ee7d80dbdc92ebfba377756bba33b2a41f64bd8818c31b4f7246f
  • Pointer size: 130 Bytes
  • Size of remote file: 65.1 kB

samples/unet_320x320_0.jpg CHANGED

Git LFS Details (before)
  • SHA256: 5a3f8abbd017e52f53dac630997d2380288fc6a8d415e6a6b1b7517595174e7d
  • Pointer size: 130 Bytes
  • Size of remote file: 24.2 kB

Git LFS Details (after)
  • SHA256: ca8fdb2f3556005f78d2eae6cf0411abb3152b3a46a1766d4c5df6c0916e8f1c
  • Pointer size: 131 Bytes
  • Size of remote file: 134 kB
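These images are stored in Git LFS, so the diff can only compare metadata: the repository itself holds a small text pointer (the "Pointer size" above, roughly 130 bytes) while the JPEG lives in LFS storage ("Size of remote file"). A pointer file has the three-line shape defined by the LFS spec, the same shape as the pointer diffs shown further down; the angle-bracketed values below are placeholders, not data from this commit:

version https://git-lfs.github.com/spec/v1
oid sha256:<hash of the binary's contents>
size <byte count of the binary>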
sdxs_08b/train.py DELETED
@@ -1,798 +0,0 @@
#from comet_ml import Experiment
import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import LambdaLR
from collections import defaultdict
from diffusers import UNet2DConditionModel, AutoencoderKL, AutoencoderKLFlux2
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image, ImageOps
import wandb
import random
import gc
from accelerate.state import DistributedType
from torch.distributed import broadcast_object_list
from torch.utils.checkpoint import checkpoint
from diffusers.models.attention_processor import AttnProcessor2_0
from datetime import datetime
import bitsandbytes as bnb
import torch.nn.functional as F
from collections import deque
from transformers import AutoTokenizer, AutoModel

# --------------------------- Parameters ---------------------------
ds_path = "/workspace/sdxs/datasets/mjnj"
project = "sdxs_08b"
batch_size = 128
base_learning_rate = 4e-5  #2.7e-5
min_learning_rate = 1e-5  #2.7e-5
num_epochs = 15
sample_interval_share = 3
cfg_dropout = 0.25
max_length = 192
use_wandb = True
use_comet_ml = False
save_model = True
use_decay = True
fbp = False
optimizer_type = "adam8bit"
torch_compile = False
unet_gradient = True
fixed_seed = False
shuffle = True
comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
comet_ml_workspace = "recoilme"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_mem_efficient_sdp(False)
dtype = torch.float32
save_barrier = 1.01
warmup_percent = 0.01
percentile_clipping = 95  #96 #97
betta2 = 0.995
eps = 1e-7
clip_grad_norm = 1.0
limit = 0
checkpoints_folder = ""
mixed_precision = "no"
gradient_accumulation_steps = 1

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

# Diffusion parameters
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 4

# Output folder for results
generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)

# Seed setup
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d"))
if fixed_seed:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# --------------------------- LoRA parameters ---------------------------
lora_name = ""
lora_rank = 32
lora_alpha = 64

print("init")

loss_ratios = {
    "mse": 1.5,
    "mae": 0.5,
}
median_coeff_steps = 256

# Median-based loss normalization: compute the COEFFICIENTS
class MedianLossNormalizer:
    def __init__(self, desired_ratios: dict, window_steps: int):
        # normalize the shares in case they don't sum to 1
        #s = sum(desired_ratios.values())
        #self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
        self.ratios = {k: float(v) for k, v in desired_ratios.items()}
        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
        self.window = window_steps

    def update_and_total(self, losses: dict):
        """
        losses: dict key -> tensor (loss values)
        Behavior:
          - buffer ABS(l) only for active (ratio > 0) losses
          - coeff = ratio / median(abs(loss))
          - total = sum(coeff * loss) over active losses
        CHANGED: buffer abs() so the median stays positive and doesn't break the division.
        """
        # buffer only the active losses
        for k, v in losses.items():
            if k in self.buffers and self.ratios.get(k, 0) > 0:
                val = v.detach().abs().mean().cpu().item()  # .item() is preferable to float() for tensors
                self.buffers[k].append(val)
                #self.buffers[k].append(float(v.detach().abs().cpu()))

        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}

        # sum only over active losses (ratio > 0)
        total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
        return total, coeffs, meds

# create the normalizer after loss_ratios is defined
normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)

# --------------------------- WandB initialization ---------------------------
if accelerator.is_main_process:
    if use_wandb:
        wandb.init(project=project+lora_name, config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        })
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace
        )
        hyper_params = {
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
        }
        comet_experiment.log_parameters(hyper_params)

# Enable Flash Attention 2 / SDPA
torch.backends.cuda.enable_flash_sdp(True)

# --------------------------- Model loading ---------------------------
vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
#vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae", torch_dtype=dtype).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("tokenizer")
text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()

# --- [UPDATED] Text encoding function (with mask and pooling) ---
def encode_texts(texts, max_length=max_length):
    # If texts are empty (for unconditional), create placeholders
    if texts is None:
        # For None we would return zeros (get_negative_embedding logic),
        # but normally a list of strings is expected here.
        pass

    with torch.no_grad():
        if isinstance(texts, str):
            texts = [texts]

        for i, prompt_item in enumerate(texts):
            messages = [
                {"role": "user", "content": prompt_item},
            ]
            prompt_item = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                #enable_thinking=True,
            )
            #print(prompt_item+"\n")
            texts[i] = prompt_item

        toks = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length
        ).to(device)

        outs = text_model(**toks, output_hidden_states=True, return_dict=True)

        # Use last_hidden_state or hidden_states[-1] (for Qwen, last_hidden_state is supposedly better -- human note: it is not)
        hidden = outs.hidden_states[-2]

        # 2. Attention mask
        attention_mask = toks["attention_mask"]

        # 3. Pooled embedding (last token)
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = hidden.shape[0]
        pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]

        #return hidden, attention_mask
        # --- NEW LOGIC: CONCATENATION FOR CROSS-ATTENTION ---
        # 1. Expand the pooled vector to a sequence [B, 1, emb]
        pooled_expanded = pooled.unsqueeze(1)

        # 2. Concatenate the token sequence and the pooled vector
        # !!! CHANGE HERE !!!: pooling goes FIRST
        # Now: [B, 1 + L, emb]. The pooled vector is the token at the START.
        new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)

        # 3. Update the attention mask for the new token
        # Attention mask: [B, 1 + L]. Prepend a 1 at the START.
        # torch.ones((batch_size, 1), device=device) creates a [B, 1] mask of ones.
        new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)

        return new_encoder_hidden_states, new_attention_mask

shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None: shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None: scaling_factor = 1.0

from diffusers import FlowMatchEulerDiscreteScheduler
num_train_timesteps = 1000
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps)

class DistributedResolutionBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = max(1, batch_size // num_replicas)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        try:
            widths = np.array(dataset["width"])
            heights = np.array(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
        self.size_groups = {}
        for w, h in self.size_keys:
            mask = (widths == w) & (heights == h)
            self.size_groups[(w, h)] = np.where(mask)[0]

        self.group_num_batches = {}
        total_batches = 0
        for size, indices in self.size_groups.items():
            num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
            self.group_num_batches[size] = num_full_batches
            total_batches += num_full_batches

        self.num_batches = (total_batches // self.num_replicas) * self.num_replicas

    def __iter__(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        all_batches = []
        rng = np.random.RandomState(self.epoch)

        for size, indices in self.size_groups.items():
            indices = indices.copy()
            if self.shuffle:
                rng.shuffle(indices)
            num_full_batches = self.group_num_batches[size]
            if num_full_batches == 0:
                continue
            valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
            batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
            start_idx = self.rank * self.batch_size
            end_idx = start_idx + self.batch_size
            gpu_batches = batches[:, start_idx:end_idx]
            all_batches.extend(gpu_batches)

        if self.shuffle:
            rng.shuffle(all_batches)
        accelerator.wait_for_everyone()
        return iter(all_batches)

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch

# --- [UPDATED] Fixed-samples helper ---
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, (w, h) in enumerate(zip(widths, heights)):
        size = (w, h)
        size_groups[size].append(i)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(samples_per_group, len(indices))
        if len(size_groups) == 1:
            n_samples = samples_to_generate
        if n_samples == 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]

        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        texts = [item["text"] for item in samples_data]

        # Encode texts on the fly to get masks and pooling
        embeddings, masks = encode_texts(texts)

        fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Created {len(fixed_samples)} groups of fixed samples by resolution")
    return fixed_samples

if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)

dataset = dataset.filter(
    lambda x: [not (path.startswith("/workspace/ds/animesfw") or path.startswith("/workspace/ds/d4/animesfw")) for path in x["image_path"]],
    batched=True,
    batch_size=10000,  # process 10k rows at a time
    num_proc=8
)
print(f"Examples left after filtering: {len(dataset)}")

# --- [UPDATED] Collate function ---
def collate_fn_simple(batch):
    # 1. Latents (VAE)
    latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)

    # 2. Take the raw text from the dataset
    raw_texts = [item["text"] for item in batch]
    texts = [
        "" if t.lower().startswith("zero")
        else "" if random.random() < cfg_dropout
        else t[1:].lstrip() if t.startswith(".")
        else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ", "").strip()
        for t in raw_texts
    ]

    # 3. Encode on the fly
    # Returns: hidden (B, L, D), mask (B, L)
    embeddings, attention_mask = encode_texts(texts)

    # the tokenizer's attention_mask is already in the right format, but cast to long just in case
    attention_mask = attention_mask.to(dtype=torch.int64)

    return latents, embeddings, attention_mask

batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle
)

dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
if accelerator.is_main_process:
    print("Total samples", len(dataloader))
dataloader = accelerator.prepare(dataloader)

start_epoch = 0
global_step = 0
total_training_steps = (len(dataloader) * num_epochs)
world_size = accelerator.state.num_processes

# UNet loading
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Loading UNet from checkpoint:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
    if unet_gradient:
        unet.enable_gradient_checkpointing()
    unet.set_use_memory_efficient_attention_xformers(False)
    try:
        unet.set_attn_processor(AttnProcessor2_0())
    except Exception as e:
        print(f"Error enabling SDPA: {e}")
        unet.set_use_memory_efficient_attention_xformers(True)
else:
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")

if lora_name:
    # ... (LoRA code unchanged, omitted for brevity; restore the original block if needed) ...
    pass

# Optimizer
if lora_name:
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
else:
    if fbp:
        trainable_params = list(unet.parameters())

def create_optimizer(name, params):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
            percentile_clipping=percentile_clipping
        )
    elif name == "adam":
        return torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
        )
    elif name == "muon":
        from muon import MuonWithAuxAdam
        trainable_params = [p for p in params if p.requires_grad]
        hidden_weights = [p for p in trainable_params if p.ndim >= 2]
        hidden_gains_biases = [p for p in trainable_params if p.ndim < 2]

        param_groups = [
            dict(params=hidden_weights, use_muon=True,
                 lr=1e-3, weight_decay=1e-4),
            dict(params=hidden_gains_biases, use_muon=False,
                 lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
        ]
        optimizer = MuonWithAuxAdam(param_groups)
        from snooc import SnooC
        return SnooC(optimizer)
    else:
        raise ValueError(f"Unknown optimizer: {name}")

if fbp:
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)
    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)
    unet, optimizer = accelerator.prepare(unet, optimizer_dict)
else:
    optimizer = create_optimizer(optimizer_type, unet.parameters())
    def lr_schedule(step):
        x = step / (total_training_steps * world_size)
        warmup = warmup_percent
        if not use_decay:
            return base_learning_rate
        if x < warmup:
            return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
        decay_ratio = (x - warmup) / (1 - warmup)
        return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
            (1 + math.cos(math.pi * decay_ratio))
    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

if torch_compile:
    print("compiling")
    unet = torch.compile(unet)
    print("compiling - ok")

# Fixed samples
fixed_samples = get_fixed_samples_by_resolution(dataset)

# --- [UPDATED] Negative-embedding helper (returns emb and mask) ---
def get_negative_embedding(neg_prompt="", batch_size=1):
    if not neg_prompt:
        hidden_dim = 2048
        seq_len = max_length
        empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
        return empty_emb, empty_mask

    uncond_emb, uncond_mask = encode_texts([neg_prompt])
    uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)

    return uncond_emb, uncond_mask

# Get the negative (empty) conditions for validation
uncond_emb, uncond_mask = get_negative_embedding("low quality")

# --- Sample-generation function ---
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    uncond_emb, uncond_mask = uncond_data

    original_model = None
    try:
        if not torch_compile:
            original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            original_model = unet.eval()

        vae.to(device=device).eval()

        all_generated_images = []
        all_captions = []

        # Unpack 4 elements per group (mask was added)
        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            sample_latents = sample_latents.to(dtype=dtype, device=device)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            sample_mask = sample_mask.to(device=device)

            latents = torch.randn(
                sample_latents.shape,
                device=device,
                dtype=sample_latents.dtype,
                generator=torch.Generator(device=device).manual_seed(seed)
            )

            scheduler.set_timesteps(n_diffusion_steps, device=device)

            for t in scheduler.timesteps:
                if guidance_scale != 1:
                    latent_model_input = torch.cat([latents, latents], dim=0)

                    # Prepare the batches for CFG (negative + positive)
                    # 1. Embeddings
                    curr_batch_size = sample_text_embeddings.shape[0]
                    seq_len = sample_text_embeddings.shape[1]
                    hidden_dim = sample_text_embeddings.shape[2]

                    neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                    text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)

                    # 2. Masks
                    neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
                    attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)

                else:
                    latent_model_input = latents
                    text_embeddings_batch = sample_text_embeddings
                    attention_mask_batch = sample_mask

                # Prediction with all conditions passed in
                model_out = original_model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings_batch,
                    encoder_attention_mask=attention_mask_batch,
                )
                flow = getattr(model_out, "sample", model_out)

                if guidance_scale != 1:
                    flow_uncond, flow_cond = flow.chunk(2)
                    flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

                latents = scheduler.step(flow, t, latents).prev_sample

            current_latents = latents
            if step == 0:
                current_latents = sample_latents

            latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
            decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
            decoded_fp32 = decoded.to(torch.float32)

            for img_idx, img_tensor in enumerate(decoded_fp32):
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)

                if np.isnan(img).any():
                    print("NaNs found, saving stopped! Step:", step)
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)

                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=96)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]}
                )
    finally:
        vae.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()

# --------------------------- Sample generation before training ---------------------------
if accelerator.is_main_process:
    if save_model:
        print("Generating samples before training starts...")
        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
accelerator.wait_for_everyone()

def save_checkpoint(unet, variant=""):
    if accelerator.is_main_process:
        if lora_name:
            save_lora_checkpoint(unet)  # NOTE: only reachable when lora_name is set; not defined in this file
        else:
            model_to_save = None
            if not torch_compile:
                model_to_save = accelerator.unwrap_model(unet)
            else:
                model_to_save = unet

            if variant != "":
                model_to_save.to(dtype=torch.float16).save_pretrained(
                    os.path.join(checkpoints_folder, f"{project}"), variant=variant
                )
            else:
                model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))

            unet = unet.to(dtype=dtype)

# --------------------------- Training loop ---------------------------
if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")

epoch_loss_points = []
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")

steps_per_epoch = len(dataloader)
sample_interval = max(1, steps_per_epoch // sample_interval_share)
min_loss = 4.

for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    unet.train()

    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        with accelerator.accumulate(unet):
            if save_model == False and epoch == 0 and step == 5:
                used_gb = torch.cuda.max_memory_allocated() / 1024**3
                print(f"Step {step}: {used_gb:.2f} GB")

            # noise
            noise = torch.randn_like(latents, dtype=latents.dtype)
            # sample t from [0, 1]
            t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
            #u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
            #t = torch.sigmoid(torch.randn_like(u))

            # interpolate between x0 and the noise
            noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
            # make integer timesteps for the UNet
            timesteps = (t * scheduler.config.num_train_timesteps).long()

            # --- UNet call with mask ---
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=embeddings,
                encoder_attention_mask=attention_mask
            ).sample

            target = noise - latents

            mse_loss = F.mse_loss(model_pred.float(), target.float())
            mae_loss = F.l1_loss(model_pred.float(), target.float())
            batch_losses.append(mse_loss.detach().item())

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            losses_dict = {}
            losses_dict["mse"] = mse_loss
            losses_dict["mae"] = mae_loss

            # === Normalize all losses ===
            abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
            total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            accelerator.backward(total_loss)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            grad = 0.0
            if not fbp:
                if accelerator.sync_gradients:
                    #with torch.amp.autocast('cuda', enabled=False):
                    grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
                    grad = float(grad_val)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
            if accelerator.is_main_process:
                if fbp:
                    current_lr = base_learning_rate
                else:
                    current_lr = lr_scheduler.get_last_lr()[0]
                batch_grads.append(grad)

                log_data = {}
                log_data["loss_mse"] = mse_loss.detach().item()
                log_data["loss_mae"] = mae_loss.detach().item()
                log_data["lr"] = current_lr
                log_data["grad"] = grad
                log_data["loss_norm"] = float(total_loss.item())
                for k, c in coeffs.items():
                    log_data[f"coeff_{k}"] = float(c)
                if accelerator.sync_gradients:
                    if use_wandb:
                        wandb.log(log_data, step=global_step)
                    if use_comet_ml:
                        comet_experiment.log_metrics(log_data, step=global_step)

                if global_step % sample_interval == 0:
                    # Pass the (emb, mask) tuple for the negative
                    if save_model:
                        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                    elif epoch % 10 == 0:
                        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                    last_n = sample_interval

                    if save_model:
                        has_losses = len(batch_losses) > 0
                        avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
                        last_loss = batch_losses[-1] if has_losses else 0.0
                        max_loss = max(avg_sample_loss, last_loss)
                        should_save = max_loss < min_loss * save_barrier
                        print(
                            f"Saving: {should_save} | Max: {max_loss:.4f} | "
                            f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
                        )
                        # 6. Save and update
                        if should_save:
                            min_loss = max_loss
                            save_checkpoint(unet)

    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
        avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0

        print(f"\nEpoch {epoch} finished. Average loss: {avg_epoch_loss:.6f}")
        log_data_ep = {
            "epoch_loss": avg_epoch_loss,
            "epoch_grad": avg_epoch_grad,
            "epoch": epoch + 1,
        }
        if use_wandb:
            wandb.log(log_data_ep)
        if use_comet_ml:
            comet_experiment.log_metrics(log_data_ep)

if accelerator.is_main_process:
    print("Training finished! Saving the final model...")
    #if save_model:
    save_checkpoint(unet, "fp16")
    if use_comet_ml:
        comet_experiment.end()
accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()

print("Done!")
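The core of this deleted trainer is a rectified-flow (flow-matching) objective: sample t ~ U[0, 1], mix clean latents with noise as x_t = (1 - t)·x0 + t·ε, and regress the UNet output onto the constant velocity ε - x0. A minimal self-contained sketch of that training step (the function name and toy signature are illustrative, not from the commit):

import torch
import torch.nn.functional as F

def flow_matching_step(unet, x0, num_train_timesteps=1000):
    # x0: clean VAE latents [B, C, H, W]
    noise = torch.randn_like(x0)
    t = torch.rand(x0.shape[0], device=x0.device)    # t ~ U[0, 1]
    t4 = t.view(-1, 1, 1, 1)
    xt = (1.0 - t4) * x0 + t4 * noise                # linear interpolation path
    timesteps = (t * num_train_timesteps).long()     # integer steps for the UNet
    pred = unet(xt, timesteps)                       # predicted velocity field
    target = noise - x0                              # constant velocity along the path
    return F.mse_loss(pred, target)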
src/sd15_2048.ipynb ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:057785145ba468215062bf2b6ea7dca9aa3186f0f010215eb00e4739bf213e17
size 51250
test.ipynb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:20f9413ba8b38673409699a49f00a6f9794a0c5153625b422b6151af0df9a940
-size 2047294
+oid sha256:25096d7d566bf784726fbdbd762781a0ff0a25523554272c3c426989633dc969
+size 5145945
train.py CHANGED
@@ -8,7 +8,7 @@ from torch.utils.data import DataLoader, Sampler
 from torch.utils.data.distributed import DistributedSampler
 from torch.optim.lr_scheduler import LambdaLR
 from collections import defaultdict
-from diffusers import UNet2DConditionModel, AutoencoderKL, AutoencoderKLFlux2
+from diffusers import UNet2DConditionModel, AutoencoderKL, AutoencoderKLFlux2, AsymmetricAutoencoderKL, FlowMatchEulerDiscreteScheduler
 from accelerate import Accelerator
 from datasets import load_from_disk
 from tqdm import tqdm
@@ -30,10 +30,10 @@ from transformers import AutoTokenizer, AutoModel
 ds_path = "/workspace/sdxs/datasets/mjnj"
 project = "unet"
 batch_size = 48
-base_learning_rate = 2.7e-5  #2.7e-5
+base_learning_rate = 4e-5  #2.7e-5
 min_learning_rate = 1e-5  #2.7e-5
 num_epochs = 50
-sample_interval_share = 5
+sample_interval_share = 10
 cfg_dropout = 0.15
 max_length = 192
 use_wandb = False
@@ -96,8 +96,8 @@ lora_alpha = 64
 print("init")
 
 loss_ratios = {
-    "mse": 1.25,
-    "mae": 0.25,
+    "mse": 1.5,
+    "mae": 0.5,
 }
 median_coeff_steps = 256
@@ -164,10 +164,12 @@ if accelerator.is_main_process:
 torch.backends.cuda.enable_flash_sdp(True)
 
 # --------------------------- Model loading ---------------------------
-vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
+#vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
+vae = AsymmetricAutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
 #vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae", torch_dtype=dtype).to(device).eval()
 tokenizer = AutoTokenizer.from_pretrained("tokenizer")
 text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
+scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
 
 # --- [UPDATED] Text encoding function (with mask and pooling) ---
 def encode_texts(texts, max_length=max_length):
@@ -237,10 +239,6 @@ if shift_factor is None: shift_factor = 0.0
 scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
 if scaling_factor is None: scaling_factor = 1.0
 
-from diffusers import FlowMatchEulerDiscreteScheduler
-num_train_timesteps = 1000
-scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps)
-
 class DistributedResolutionBatchSampler(Sampler):
     def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
         self.dataset = dataset
@@ -708,7 +706,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                 accelerator.wait_for_everyone()
 
-            accelerator.backward(total_loss)
+            accelerator.backward(mse_loss)
 
             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                 accelerator.wait_for_everyone()
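The functional change in the last hunk is the optimization target: the median-normalized total_loss is still computed and its coefficients still logged, but gradients now come from the raw MSE term alone, so the MAE share no longer influences the weights. In effect (a simplified sketch, not code from the commit):

total_loss, coeffs, meds = normalizer.update_and_total({"mse": mse_loss, "mae": mae_loss})
accelerator.backward(mse_loss)   # was: accelerator.backward(total_loss)
# total_loss is now a logging-only scalar; mae_loss receives no gradient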
unet/config.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:dd80ac5e521c295e9b9cc5361127114bcbfea059d76f3e3a80775c91ac666281
-size 1798
+oid sha256:78d4828222ad3a8cddeadf895d9a3afce5c95869d374458dc2c7e5d3b9bf9864
+size 1813
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f3bec6d01b481146de7c8e2adcbce93df1f1cc10ac89834e551fb2d450d286f4
-size 6078588464
+oid sha256:5e903b0e21f57f4ebe996e08e09c76f0605377cde3eea67fe8b9ffce399b153f
+size 3566239360
{sdxs_08b → unet_sdxl5}/config.json RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f4c96435f2980db8743704e9361889fb5df8c50443518f76cfe966e8dfc9dc53
-size 1803
+oid sha256:dd80ac5e521c295e9b9cc5361127114bcbfea059d76f3e3a80775c91ac666281
+size 1798
{sdxs_08b → unet_sdxl5}/diffusion_pytorch_model.safetensors RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9b771189641f108016e1642640e43cb7cc65924a9e6d104cd46831854771cb7b
-size 3376002424
+oid sha256:f3bec6d01b481146de7c8e2adcbce93df1f1cc10ac89834e551fb2d450d286f4
+size 6078588464
vae/.gitattributes ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
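These are the stock Hugging Face LFS attribute rules: each pattern routes matching files through the LFS filter, so they are committed as pointers rather than blobs. Rules of exactly this form are what the standard git-lfs CLI appends, e.g.:

git lfs track "*.safetensors"   # appends the *.safetensors rule above to .gitattributes
git add .gitattributes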
vae/config.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e2e5b25e39aec4b6a75e4837adec277dfc830e00992e6ce4dd75eb2627d73197
-size 774
+oid sha256:8bf69a4b1ec4c3b6666326d7d17e98e7f7ed6880084c702101bdb3e75905535c
+size 773
vae/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:31c83db36d96ddfd42003f85abe8c22bd03a07b8174135351345d2726bd75c38
+oid sha256:ffde397a3e78a779adff8ba78297f66d01af5e397512f6ed6d500df30e9833a1
 size 382598708
vae/train_vae_fdl_distil.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import gc
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import torchvision.transforms as transforms
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
17
+ # QWEN: импорт класса
18
+ from diffusers import AutoencoderKLQwenImage
19
+ from diffusers import AutoencoderKLWan
20
+
21
+ from accelerate import Accelerator
22
+ from PIL import Image, UnidentifiedImageError
23
+ from tqdm import tqdm
24
+ import bitsandbytes as bnb
25
+ import wandb
26
+ import lpips # pip install lpips
27
+ from FDL_pytorch import FDL_loss # pip install fdl-pytorch
28
+ from collections import deque
29
+
30
+ # --------------------------- Параметры ---------------------------
31
+ ds_path = "/workspace/d23"
32
+ project = "vae5"
33
+ batch_size = 2
34
+ base_learning_rate = 4e-5
35
+ min_learning_rate = 2e-5
36
+ num_epochs = 10
37
+ sample_interval_share = 10
38
+ use_wandb = True
39
+ save_model = True
40
+ use_decay = True
41
+ optimizer_type = "adam8bit"
42
+ dtype = torch.float32
43
+
44
+ model_resolution = 256
45
+ high_resolution = 512
46
+ limit = 0
47
+ save_barrier = 1.3
48
+ warmup_percent = 0.005
49
+ percentile_clipping = 99
50
+ beta2 = 0.997
51
+ eps = 1e-8
52
+ clip_grad_norm = 1.0
53
+ mixed_precision = "no"
54
+ gradient_accumulation_steps = 4
55
+ generated_folder = "samples"
56
+ save_as = "vae6"
57
+ num_workers = 0
58
+ device = None
59
+
60
+ # --- Режимы обучения ---
61
+ # QWEN: учим только декодер
62
+ train_decoder_only = False
63
+ train_up_only = False
64
+ full_training = True # если True — учим весь VAE и добавляем KL (ниже)
65
+ kl_ratio = 0.001
66
+
67
+ # Доли лоссов
68
+ loss_ratios = {
69
+ "lpips": 0.55,#0.50,
70
+ "fdl" : 0.05,#0.25,
71
+ "edge": 0.05,
72
+ "mse": 0.10,
73
+ "mae": 0.04,
74
+ "kl": 0.001, # активируем при full_training=True
75
+ "vae2": 0.199,
76
+ }
77
+ median_coeff_steps = 1000
78
+
79
+ resize_long_side = 1280 # ресайз длинной стороны исходных картинок
80
+
81
+ # QWEN: конфиг загрузки модели
82
+ vae_kind = "kl" # "qwen" или "kl" (обычный)
83
+
84
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
85
+
86
+ accelerator = Accelerator(
87
+ mixed_precision=mixed_precision,
88
+ gradient_accumulation_steps=gradient_accumulation_steps
89
+ )
90
+ device = accelerator.device
91
+
92
+ # reproducibility
93
+ seed = int(datetime.now().strftime("%Y%m%d"))
94
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
95
+ torch.backends.cudnn.benchmark = False
96
+
97
+ # --------------------------- WandB ---------------------------
98
+ if use_wandb and accelerator.is_main_process:
99
+ wandb.init(project=project, config={
100
+ "batch_size": batch_size,
101
+ "base_learning_rate": base_learning_rate,
102
+ "num_epochs": num_epochs,
103
+ "optimizer_type": optimizer_type,
104
+ "model_resolution": model_resolution,
105
+ "high_resolution": high_resolution,
106
+ "gradient_accumulation_steps": gradient_accumulation_steps,
107
+ "train_decoder_only": train_decoder_only,
108
+ "full_training": full_training,
109
+ "kl_ratio": kl_ratio,
110
+ "vae_kind": vae_kind,
111
+ })
112
+
113
+ # --------------------------- VAE ---------------------------
114
+ def get_core_model(model):
115
+ m = model
116
+ # если модель уже обёрнута torch.compile
117
+ if hasattr(m, "_orig_mod"):
118
+ m = m._orig_mod
119
+ return m
120
+
121
+ def is_video_vae(model) -> bool:
122
+ # WAN/Qwen — это видео-VAEs
123
+ if vae_kind in ("wan", "qwen"):
124
+ return True
125
+ # fallback по структуре (если понадобится)
126
+ try:
127
+ core = get_core_model(model)
128
+ enc = getattr(core, "encoder", None)
129
+ conv_in = getattr(enc, "conv_in", None)
130
+ w = getattr(conv_in, "weight", None)
131
+ if isinstance(w, torch.nn.Parameter):
132
+ return w.ndim == 5
133
+ except Exception:
134
+ pass
135
+ return False
136
+
137
+ # загрузка
138
+ if vae_kind == "qwen":
139
+ vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
140
+ else:
141
+ if vae_kind == "wan":
142
+ vae = AutoencoderKLWan.from_pretrained(project)
143
+ else:
144
+ # старое поведение (пример)
145
+ if model_resolution==high_resolution:
146
+ vae = AutoencoderKL.from_pretrained(project)
147
+ else:
148
+ vae = AsymmetricAutoencoderKL.from_pretrained(project)
149
+
150
+ vae = vae.to(dtype)
151
+
152
+ # --------------------------- VAE2 (Distillation Teacher) ---------------------------
153
+ # Загружаем учителя (SD 1.4) для дистилляции
154
+ print("[INFO] Loading VAE2 (Teacher) for distillation...")
155
+ vae2 = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
156
+ vae2.requires_grad_(False)
157
+ vae2.eval()
158
+ # vae2 перенесем на device позже внутри accelerator.prepare или явно,
159
+ # но для надежности сделаем это здесь, чтобы не занимать лишнюю память оптимизатором
160
+ vae2.to(device, dtype=dtype)
161
+
162
+ # Адаптер для проекции 16 каналов студента -> 4 канала учителя
163
+ # Kernel size 1 делает линейную проекцию по пикселям
164
+ distill_adapter = torch.nn.Conv2d(16, 4, kernel_size=1, stride=1, padding=0, bias=True)
165
+ distill_adapter.to(device, dtype=dtype)
166
+ distill_adapter.train() # Адаптер мы обучаем!
167
+
168
+
169
+ # torch.compile (опционально)
170
+ if hasattr(torch, "compile"):
171
+ try:
172
+ vae = torch.compile(vae)
173
+ except Exception as e:
174
+ print(f"[WARN] torch.compile failed: {e}")
175
+
176
+ # --------------------------- Freeze/Unfreeze ---------------------------
177
+ core = get_core_model(vae)
178
+
179
+ for p in core.parameters():
180
+ p.requires_grad = False
181
+
182
+ unfrozen_param_names = []
183
+
184
+ if full_training and not train_decoder_only:
185
+ for name, p in core.named_parameters():
186
+ p.requires_grad = True
187
+ unfrozen_param_names.append(name)
188
+ loss_ratios["kl"] = float(kl_ratio)
189
+ trainable_module = core
190
+ else:
191
+ # учим только 0-й блок декодера + post_quant_conv
192
+ if hasattr(core, "decoder"):
193
+ if train_up_only:#hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
194
+ # --- только 0-й up_block ---
195
+ for name, p in core.decoder.up_blocks[0].named_parameters():
196
+ p.requires_grad = True
197
+ unfrozen_param_names.append(f"{name}")
198
+ else:
199
+ print("Decoder — fallback to full decoder")
200
+ for name, p in core.decoder.named_parameters():
201
+ p.requires_grad = True
202
+ unfrozen_param_names.append(f"decoder.{name}")
203
+ if hasattr(core, "post_quant_conv"):
204
+ for name, p in core.post_quant_conv.named_parameters():
205
+ p.requires_grad = True
206
+ unfrozen_param_names.append(f"post_quant_conv.{name}")
207
+ trainable_module = core.decoder if hasattr(core, "decoder") else core
208
+
209
+
210
+ print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
211
+ for nm in unfrozen_param_names[:200]:
212
+ print(" ", nm)
213
+
214
+ # --------------------------- Датасет ---------------------------
215
+ class PngFolderDataset(Dataset):
216
+ def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
217
+ self.root_dir = root_dir
218
+ self.resolution = resolution
219
+ self.paths = []
220
+ for root, _, files in os.walk(root_dir):
221
+ for fname in files:
222
+ if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
223
+ self.paths.append(os.path.join(root, fname))
224
+ if limit:
225
+ self.paths = self.paths[:limit]
226
+ valid = []
227
+ for p in self.paths:
228
+ try:
229
+ with Image.open(p) as im:
230
+ im.verify()
231
+ valid.append(p)
232
+ except (OSError, UnidentifiedImageError):
233
+ continue
234
+ self.paths = valid
235
+ if len(self.paths) == 0:
236
+ raise RuntimeError(f"No valid PNG images found under {root_dir}")
237
+ random.shuffle(self.paths)
238
+
239
+ def __len__(self):
240
+ return len(self.paths)
241
+
242
+ def __getitem__(self, idx):
243
+ p = self.paths[idx % len(self.paths)]
244
+ with Image.open(p) as img:
245
+ img = img.convert("RGB")
246
+ if not resize_long_side or resize_long_side <= 0:
247
+ return img
248
+ w, h = img.size
249
+ long = max(w, h)
250
+ if long <= resize_long_side:
251
+ return img
252
+ scale = resize_long_side / float(long)
253
+ new_w = int(round(w * scale))
254
+ new_h = int(round(h * scale))
255
+ return img.resize((new_w, new_h), Image.BICUBIC)
256
+
257
+ def random_crop(img, sz):
258
+ w, h = img.size
259
+ if w < sz or h < sz:
260
+ img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
261
+ x = random.randint(0, max(1, img.width - sz))
262
+ y = random.randint(0, max(1, img.height - sz))
263
+ return img.crop((x, y, x + sz, y + sz))
264
+
265
+ tfm = transforms.Compose([
266
+ transforms.ToTensor(),
267
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
268
+ ])
269
+
270
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
271
+ if len(dataset) < batch_size:
272
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
273
+
274
+ def collate_fn(batch):
275
+ imgs = []
276
+ for img in batch:
277
+ img = random_crop(img, high_resolution)
278
+ imgs.append(tfm(img))
279
+ return torch.stack(imgs)
280
+
281
+ dataloader = DataLoader(
282
+ dataset,
283
+ batch_size=batch_size,
284
+ shuffle=True,
285
+ collate_fn=collate_fn,
286
+ num_workers=num_workers,
287
+ pin_memory=True,
288
+ drop_last=True
289
+ )
290
+
291
+ # --------------------------- Оптимизатор ---------------------------
292
+ def get_param_groups(module, weight_decay=0.001):
293
+ no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
294
+ decay_params, no_decay_params = [], []
295
+ for n, p in vae.named_parameters(): # глобально по vae, с фильтром requires_grad
296
+ if not p.requires_grad:
297
+ continue
298
+ if any(nd in n for nd in no_decay):
299
+ no_decay_params.append(p)
300
+ else:
301
+ decay_params.append(p)
302
+ return [
303
+ {"params": decay_params, "weight_decay": weight_decay},
304
+ {"params": no_decay_params, "weight_decay": 0.0},
305
+ ]
306
+
307
+ def get_param_groups(module, weight_decay=0.001):
308
+ no_decay_tokens = ("bias", "norm", "rms", "layernorm")
309
+ decay_params, no_decay_params = [], []
310
+ for n, p in module.named_parameters():
311
+ if not p.requires_grad:
312
+ continue
313
+ n_l = n.lower()
314
+ if any(t in n_l for t in no_decay_tokens):
315
+ no_decay_params.append(p)
316
+ else:
317
+ decay_params.append(p)
318
+ return [
319
+ {"params": decay_params, "weight_decay": weight_decay},
320
+ {"params": no_decay_params, "weight_decay": 0.0},
321
+ ]
322
+
323
+ def create_optimizer(name, param_groups):
324
+ if name == "adam8bit":
325
+ return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
326
+ raise ValueError(name)
327
+
328
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
329
+
330
+ # --- ИЗМЕНЕНИЕ: Добавляем параметры адаптера в оптимизатор ---
331
+ # Адаптер маленький, weight_decay ему особо не нужен, но пусть будет стандартный
332
+ adapter_params = get_param_groups(distill_adapter, weight_decay=0.001)
333
+ param_groups.extend(adapter_params)
334
+ optimizer = create_optimizer(optimizer_type, param_groups)
335
+
336
+ # --------------------------- LR schedule ---------------------------
337
+ batches_per_epoch = len(dataloader)
338
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
339
+ total_steps = steps_per_epoch * num_epochs
340
+
341
+ def lr_lambda(step):
342
+ if not use_decay:
343
+ return 1.0
344
+ x = float(step) / float(max(1, total_steps))
345
+ warmup = float(warmup_percent)
346
+ min_ratio = float(min_learning_rate) / float(base_learning_rate)
347
+ if x < warmup:
348
+ return min_ratio + (1.0 - min_ratio) * (x / warmup)
349
+ decay_ratio = (x - warmup) / (1.0 - warmup)
350
+ return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
351
+
352
+ scheduler = LambdaLR(optimizer, lr_lambda)
+
+ # Preparation
+ dataloader, vae, distill_adapter, optimizer, scheduler = accelerator.prepare(
+     dataloader, vae, distill_adapter, optimizer, scheduler
+ )
+ # vae2 stays a plain module on the GPU; accelerator does not touch it
+ # ----------------------------------------------------------
+
+ trainable_params = [p for p in vae.parameters() if p.requires_grad] + \
+                    [p for p in distill_adapter.parameters() if p.requires_grad]
+
+ # fdl
+ fdl_loss = FDL_loss()
+ fdl_loss = fdl_loss.to(accelerator.device)
+
+ # --------------------------- LPIPS and helpers ---------------------------
+ _lpips_net = None
+ def _get_lpips():
+     global _lpips_net
+     if _lpips_net is None:
+         _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
+     return _lpips_net
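+ # Lazy singleton: the VGG-based LPIPS net is loaded once on first use and then
+ # shared by both the training loss and sample evaluation.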
+
+ _sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
+ _sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
+ def sobel_edges(x: torch.Tensor) -> torch.Tensor:
+     C = x.shape[1]
+     kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+     ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+     gx = F.conv2d(x, kx, padding=1, groups=C)
+     gy = F.conv2d(x, ky, padding=1, groups=C)
+     return torch.sqrt(gx * gx + gy * gy + 1e-12)
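+ # For x of shape [B, C, H, W] this returns the per-channel gradient magnitude
+ # sqrt(Gx^2 + Gy^2) with the same shape; groups=C makes the convolution
+ # depthwise, so each channel is filtered with the same 3x3 Sobel kernels.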
+
+ class MedianLossNormalizer:
+     def __init__(self, desired_ratios: dict, window_steps: int):
+         s = sum(desired_ratios.values())
+         self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
+         self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
+         self.window = window_steps
+
+     def update_and_total(self, abs_losses: dict):
+         for k, v in abs_losses.items():
+             if k in self.buffers:
+                 self.buffers[k].append(float(v.detach().abs().cpu()))
+         meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
+         coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
+         total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
+         return total, coeffs, meds
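+ # Each coefficient is ratio_k / median_k, so every term is rescaled to its
+ # desired share regardless of raw magnitude. Illustrative example: with ratios
+ # {"mse": 0.5, "lpips": 0.5} and running medians {"mse": 0.02, "lpips": 0.4},
+ # the coefficients are 25.0 and 1.25, and a median-sized value of either loss
+ # contributes exactly 0.5 to the total.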
+
+ if full_training and not train_decoder_only:
+     loss_ratios["kl"] = float(kl_ratio)
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
+
+ # --------------------------- Samples ---------------------------
+ @torch.no_grad()
+ def get_fixed_samples(n=3):
+     idx = random.sample(range(len(dataset)), min(n, len(dataset)))
+     pil_imgs = [dataset[i] for i in idx]
+     tensors = []
+     for img in pil_imgs:
+         img = random_crop(img, high_resolution)
+         tensors.append(tfm(img))
+     return torch.stack(tensors).to(accelerator.device, dtype)
+
+ fixed_samples = get_fixed_samples()
+
+ @torch.no_grad()
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
+     arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
+     return Image.fromarray(arr)
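+ # Maps [-1, 1] to [0, 255]: (x + 1) * 127.5 sends -1 -> 0, 0 -> 127.5 (127 after
+ # byte truncation) and 1 -> 255, then CHW is transposed to HWC for PIL.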
+
+
+ @torch.no_grad()
+ def generate_and_save_samples(step=None):
+     try:
+         temp_vae = accelerator.unwrap_model(vae).eval()
+         lpips_net = _get_lpips()
+         with torch.no_grad():
+             orig_high = fixed_samples
+             orig_low = F.interpolate(
+                 orig_high,
+                 size=(model_resolution, model_resolution),
+                 mode="bilinear",
+                 align_corners=False
+             )
+             model_dtype = next(temp_vae.parameters()).dtype
+             orig_low = orig_low.to(dtype=model_dtype)
+
+             # Encode/decode, accounting for the video mode
+             if is_video_vae(temp_vae):
+                 x_in = orig_low.unsqueeze(2)  # [B,3,1,H,W]
+                 enc = temp_vae.encode(x_in)
+                 latents_mean = enc.latent_dist.mean
+                 dec = temp_vae.decode(latents_mean).sample  # [B,3,1,H,W]
+                 rec = dec.squeeze(2)  # [B,3,H,W]
+             else:
+                 enc = temp_vae.encode(orig_low)
+                 latents_mean = enc.latent_dist.mean
+                 rec = temp_vae.decode(latents_mean).sample
+
+             # Resize if needed
+             #if rec.shape[-2:] != orig_high.shape[-2:]:
+             #    rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
+
+             # Save all real/decoded pairs
+             for i in range(rec.shape[0]):
+                 real_img = _to_pil_uint8(orig_high[i])
+                 dec_img = _to_pil_uint8(rec[i])
+                 real_img.save(f"{generated_folder}/sample_real_{i}.jpg", quality=95)
+                 dec_img.save(f"{generated_folder}/sample_decoded_{i}.jpg", quality=95)
+
+             # LPIPS
+             lpips_scores = []
+             for i in range(rec.shape[0]):
+                 orig_full = orig_high[i:i+1].to(torch.float32)
+                 rec_full = rec[i:i+1].to(torch.float32)
+                 #if rec_full.shape[-2:] != orig_full.shape[-2:]:
+                 #    rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
+                 lpips_val = lpips_net(orig_full, rec_full).item()
+                 lpips_scores.append(lpips_val)
+             avg_lpips = float(np.mean(lpips_scores))
+
+             # W&B logging
+             if use_wandb and accelerator.is_main_process:
+                 log_data = {"lpips_mean": avg_lpips}
+                 for i in range(rec.shape[0]):
+                     log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.jpg", caption=f"real_{i}")
+                     log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.jpg", caption=f"decoded_{i}")
+                 wandb.log(log_data, step=step)
+
+     finally:
+         gc.collect()
+         torch.cuda.empty_cache()
+
+
+ if accelerator.is_main_process and save_model:
+     print("Generating samples before training starts...")
+     generate_and_save_samples(0)
+
+ accelerator.wait_for_everyone()
+
+ # --------------------------- Training ---------------------------
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
+ global_step = 0
+ min_loss = float("inf")
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
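+ # total_steps // (sample_interval_share * num_epochs) spaces the dumps evenly:
+ # roughly sample_interval_share sample grids per epoch over the whole run.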
+
+ for epoch in range(num_epochs):
+     vae.train()
+     batch_losses, batch_grads = [], []
+     track_losses = {k: [] for k in loss_ratios.keys()}
+
+     for imgs in dataloader:
+         with accelerator.accumulate(vae):
+             imgs = imgs.to(accelerator.device)
+
+             if high_resolution != model_resolution:
+                 imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
+             else:
+                 imgs_low = imgs
+
+             model_dtype = next(vae.parameters()).dtype
+             imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
+
+             # QWEN: encode/decode with T=1
+             if is_video_vae(vae):
+                 x_in = imgs_low_model.unsqueeze(2)  # [B,3,1,H,W]
+                 enc = vae.encode(x_in)
+                 latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
+                 dec = vae.decode(latents).sample  # [B,3,1,H,W]
+                 rec = dec.squeeze(2)  # [B,3,H,W]
+             else:
+                 enc = vae.encode(imgs_low_model)
+                 latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
+                 rec = vae.decode(latents).sample
+
+             #if rec.shape[-2:] != imgs.shape[-2:]:
+             #    rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
+
+             rec_f32 = rec.to(torch.float32)
+             imgs_f32 = imgs.to(torch.float32)
+
+             # --- CHANGE: compute the VAE2 distillation loss ---
+             # 1. Get the teacher (SD) latents.
+             # The SD VAE expects input in [-1, 1]; imgs_low_model is already normalized that way.
+             with torch.no_grad():
+                 # In SD, scale_factor=0.18215 is normally applied AFTER the encoder, for diffusion.
+                 # For distillation we can compare the "raw" distributions (moments) directly.
+                 # The point is to compare apples to apples, so we take .mean (the deterministic output).
+                 teacher_dist = vae2.encode(imgs_low_model).latent_dist
+                 teacher_mean = teacher_dist.mean
+
+             # 2. Get the student latents.
+             # Already computed above in 'enc'; enc.latent_dist.mean has 16 channels.
+             student_mean = enc.latent_dist.mean
+
+             # 3. Project 16 -> 4 channels through our trainable adapter.
+             # distill_adapter is already wrapped by accelerator, so the dtype is right.
+             student_projected = distill_adapter(student_mean)
+
+             # 4. MSE between the student projection and the teacher.
+             # Both tensors should be [B, 4, H_lat, W_lat].
+             loss_distill = F.mse_loss(student_projected.float(), teacher_mean.float())
+             # ------------------------------------------------
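+             # distill_adapter itself is defined earlier in the script. A minimal
+             # module consistent with its use here (16-channel student latents ->
+             # 4-channel teacher latents) could be a single 1x1 convolution, e.g.:
+             #     distill_adapter = torch.nn.Conv2d(16, 4, kernel_size=1)
+             # (an illustrative assumption, not necessarily the actual definition).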
+
+             abs_losses = {
+                 "mae": F.l1_loss(rec_f32, imgs_f32),
+                 "mse": F.mse_loss(rec_f32, imgs_f32),
+                 "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
+                 "fdl": fdl_loss(rec_f32, imgs_f32),
+                 "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
+                 "vae2": loss_distill,  # <--- add the distillation term to the loss dict
+             }
+
+             if full_training and not train_decoder_only:
+                 mean = enc.latent_dist.mean
+                 logvar = enc.latent_dist.logvar
+                 kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
+                 abs_losses["kl"] = kl
+             else:
+                 abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
+
+             total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
+
+             if torch.isnan(total_loss) or torch.isinf(total_loss):
+                 raise RuntimeError("NaN/Inf loss")
+
+             accelerator.backward(total_loss)
+
+             grad_norm = torch.tensor(0.0, device=accelerator.device)
+             if accelerator.sync_gradients:
+                 grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
+             optimizer.step()
+             scheduler.step()
+             optimizer.zero_grad(set_to_none=True)
+             global_step += 1
+             progress.update(1)
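+             # accelerator.sync_gradients is True only on the micro-batch where the
+             # accumulated gradients are actually applied, so the norm is measured
+             # and clipped once per effective optimizer step; the accelerate-wrapped
+             # optimizer skips its step on the other micro-batches.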
+
+             if accelerator.is_main_process:
+                 try:
+                     current_lr = optimizer.param_groups[0]["lr"]
+                 except Exception:
+                     current_lr = scheduler.get_last_lr()[0]
+
+                 batch_losses.append(total_loss.detach().item())
+                 batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
+                 for k, v in abs_losses.items():
+                     track_losses[k].append(float(v.detach().item()))
+
+                 if use_wandb and accelerator.sync_gradients:
+                     log_dict = {
+                         "total_loss": float(total_loss.detach().item()),
+                         "learning_rate": current_lr,
+                         "epoch": epoch,
+                         "grad_norm": batch_grads[-1],
+                     }
+                     for k, v in abs_losses.items():
+                         log_dict[f"loss_{k}"] = float(v.detach().item())
+                     for k in coeffs:
+                         log_dict[f"coeff_{k}"] = float(coeffs[k])
+                         log_dict[f"median_{k}"] = float(meds[k])
+                     wandb.log(log_dict, step=global_step)
+
+             if global_step > 0 and global_step % sample_interval == 0:
+                 if accelerator.is_main_process:
+                     generate_and_save_samples(global_step)
+                 accelerator.wait_for_everyone()
+
+                 n_micro = sample_interval * gradient_accumulation_steps
+                 avg_loss = float(np.mean(batch_losses[-n_micro:])) if batch_losses else float("nan")
+                 avg_grad = float(np.mean(batch_grads[-n_micro:])) if batch_grads else 0.0
+
+                 if accelerator.is_main_process:
+                     print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
+                     if save_model and avg_loss < min_loss * save_barrier:
+                         min_loss = avg_loss
+                         accelerator.unwrap_model(vae).save_pretrained(save_as)
+                     if use_wandb:
+                         wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
+
+     if accelerator.is_main_process:
+         epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
+         print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
+         if use_wandb:
+             wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
+
+ # --------------------------- Final save ---------------------------
+ if accelerator.is_main_process:
+     print("Training finished – saving final model")
+     if save_model:
+         accelerator.unwrap_model(vae).save_pretrained(save_as)
+
+ accelerator.free_memory()
+ if torch.distributed.is_initialized():
+     torch.distributed.destroy_process_group()
+ print("Done!")
vae2x/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e2e5b25e39aec4b6a75e4837adec277dfc830e00992e6ce4dd75eb2627d73197
+ size 774
vae2x/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31c83db36d96ddfd42003f85abe8c22bd03a07b8174135351345d2726bd75c38
+ size 382598708