recoilme committed
Commit 70c33fe · verified · 1 Parent(s): 4c2f739

Upload folder using huggingface_hub

.ipynb_checkpoints/config-checkpoint.json ADDED
@@ -0,0 +1,72 @@
+ {
+ "_class_name": "UNet2DConditionModel",
+ "_diffusers_version": "0.36.0",
+ "act_fn": "silu",
+ "addition_embed_type": null,
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": null,
+ "attention_head_dim": [
+ 10,
+ 20,
+ 20
+ ],
+ "attention_type": "default",
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280
+ ],
+ "center_input_sample": false,
+ "class_embed_type": null,
+ "class_embeddings_concat": false,
+ "conv_in_kernel": 3,
+ "conv_out_kernel": 3,
+ "cross_attention_dim": 1024,
+ "cross_attention_norm": null,
+ "down_block_types": [
+ "DownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "dropout": 0.0,
+ "dual_cross_attention": false,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "in_channels": 32,
+ "layers_per_block": 2,
+ "mid_block_only_cross_attention": null,
+ "mid_block_scale_factor": 1.0,
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "out_channels": 32,
+ "projection_class_embeddings_input_dim": null,
+ "resnet_out_scale_factor": 1.0,
+ "resnet_skip_time_act": false,
+ "resnet_time_scale_shift": "default",
+ "reverse_transformer_layers_per_block": null,
+ "sample_size": null,
+ "time_cond_proj_dim": null,
+ "time_embedding_act_fn": null,
+ "time_embedding_dim": null,
+ "time_embedding_type": "positional",
+ "timestep_post_act": null,
+ "transformer_layers_per_block": [
+ 1,
+ 2,
+ 4
+ ],
+ "up_block_types": [
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "UpBlock2D"
+ ],
+ "upcast_attention": false,
+ "use_linear_projection": true
+ }
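
The file above is a standard diffusers model config. As a quick sanity check (not part of the commit; a minimal sketch assuming only that diffusers is installed and the file is saved locally as config.json), it can be instantiated directly:

import json

import torch
from diffusers import UNet2DConditionModel

with open("config.json") as f:  # hypothetical local copy of the file above
    config = json.load(f)

# from_config ignores the private "_class_name"/"_diffusers_version" entries
unet = UNet2DConditionModel.from_config(config)
print(sum(p.numel() for p in unet.parameters()))

# Smoke test: 32 latent channels and cross_attention_dim=1024, per the config;
# three down blocks mean two downsamplings, so H and W should be multiples of 4.
latents = torch.randn(1, 32, 40, 40)
text_emb = torch.randn(1, 77, 1024)
out = unet(latents, torch.tensor([0]), text_emb).sample
assert out.shape == latents.shape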
.ipynb_checkpoints/dataset_flux-checkpoint.py ADDED
@@ -0,0 +1,413 @@
+ # pip install flash-attn --no-build-isolation
+ import torch
+ import torch.nn.functional as F  # needed by encode_texts_batch (was missing)
+ import os
+ import gc
+ import numpy as np
+ import random
+ import json
+ import shutil
+ import time
+
+ from datasets import Dataset, load_from_disk, concatenate_datasets
+ from diffusers import AutoencoderKL, AutoencoderKLWan, AutoencoderKLFlux2
+ from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
+ from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
+ from typing import Dict, List, Tuple, Optional, Any
+ from PIL import Image
+ from tqdm import tqdm
+ from datetime import timedelta
+
+ # ---------------- 1️⃣ Settings ----------------
+ dtype = torch.float32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ batch_size = 5
+ min_size = 192 #384 #320 #192 #256 #192
+ max_size = 320 #768 #640 #384 #256 #384
+ step = 64 #64
+ empty_share = 0.0
+ limit = 0
+ # Main processing paths
+ folder_path = "/workspace/mjnj" #alchemist"
+ save_path = "/workspace/sdxs/datasets/mjnj" #"alchemist"
+ os.makedirs(save_path, exist_ok=True)
+
+ # Helper to release CUDA memory
+ def clear_cuda_memory():
+     if torch.cuda.is_available():
+         used_gb = torch.cuda.max_memory_allocated() / 1024**3
+         print(f"used_gb: {used_gb:.2f} GB")
+         torch.cuda.empty_cache()
+         gc.collect()
+
+ # ---------------- 2️⃣ Model loading ----------------
+ def load_models():
+     print("Loading models...")
+     #vae = AutoencoderKL.from_pretrained("AiArtLab/sdxs", subfolder="vae1x", torch_dtype=dtype).to(device).eval()
+     vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae", torch_dtype=dtype).to(device).eval()
+
+     #model_name = "Qwen/Qwen3-0.6B"
+     #tokenizer = AutoTokenizer.from_pretrained(model_name)
+     #model = AutoModelForCausalLM.from_pretrained(
+     #    model_name,
+     #    torch_dtype=dtype,
+     #    device_map=device
+     #).eval()
+     #tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
+     #model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B').to("cuda")
+     return vae  #, model, tokenizer
+
+ #vae, model, tokenizer = load_models()
+ vae = load_models()
+
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
+ if shift_factor is None:
+     shift_factor = 0.0
+
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
+ if scaling_factor is None:
+     scaling_factor = 1.0
+
+ latents_mean = getattr(vae.config, "latents_mean", None)
+ latents_std = getattr(vae.config, "latents_std", None)
+
+ # ---------------- 3️⃣ Transforms ----------------
+ def get_image_transform(min_size=256, max_size=512, step=64):
+     def transform(img, dry_run=False):
+         # Remember the original image dimensions
+         original_width, original_height = img.size
+
+         # 0. Resize so that the longer side equals max_size
+         if original_width >= original_height:
+             new_width = max_size
+             new_height = int(max_size * original_height / original_width)
+         else:
+             new_height = max_size
+             new_width = int(max_size * original_width / original_height)
+
+         if new_height < min_size or new_width < min_size:
+             # 1. Resize so that the shorter side equals min_size
+             if original_width <= original_height:
+                 new_width = min_size
+                 new_height = int(min_size * original_height / original_width)
+             else:
+                 new_height = min_size
+                 new_width = int(min_size * original_width / original_height)
+
+         # 2. If a side exceeds max_size, prepare crop sizes snapped to `step`
+         crop_width = min(max_size, (new_width // step) * step)
+         crop_height = min(max_size, (new_height // step) * step)
+
+         # Make sure the crop is not smaller than min_size
+         crop_width = max(min_size, crop_width)
+         crop_height = max(min_size, crop_height)
+
+         # If only the target dimensions were requested
+         if dry_run:
+             return crop_width, crop_height
+
+         # Convert to RGB and resize
+         img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
+
+         # Crop coordinates (offset one third from the top to cut off watermarks)
+         top = (new_height - crop_height) // 3
+         left = 0
+
+         # Crop the image
+         img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
+
+         # Final dimensions after all transforms
+         final_width, final_height = img_cropped.size
+
+         # To tensor
+         img_tensor = ToTensor()(img_cropped)
+         img_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img_tensor)
+         return img_tensor, img_cropped, final_width, final_height
+
+     return transform
+
+ # ---------------- 4️⃣ Processing helpers ----------------
+ def last_token_pool(last_hidden_states: torch.Tensor,
+                     attention_mask: torch.Tensor) -> torch.Tensor:
+     # Detect whether the batch uses left padding
+     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+     if left_padding:
+         return last_hidden_states[:, -1]
+     else:
+         sequence_lengths = attention_mask.sum(dim=1) - 1
+         batch_size = last_hidden_states.shape[0]
+         return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+ def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150, normalize=False):
+     with torch.inference_mode():
+         # Tokenize
+         batch = tokenizer(
+             texts,
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=max_length
+         ).to(device)
+
+         # Forward pass
+         #outputs = model(**batch)
+
+         # Last-token pooling
+         #embeddings = last_token_pool(outputs.last_hidden_state, batch["attention_mask"])
+
+         # L2 normalization (optional; usually wanted for semantic search)
+         #if normalize:
+         #    embeddings = F.normalize(embeddings, p=2, dim=1)
+
+         # Run through the base model (inside the CausalLM wrapper)
+         outputs = model.model(**batch, output_hidden_states=True)
+
+         # Take the last layer (embeddings of all tokens)
+         hidden_states = outputs.hidden_states[-1]  # [B, L, D]
+
+         # Optionally normalize each token embedding (as in CLIP)
+         if normalize:
+             hidden_states = F.normalize(hidden_states, p=2, dim=-1)
+
+         return hidden_states.cpu().numpy()  # embeddings.unsqueeze(1).cpu().numpy()
+
+ def clean_label(label):
+     label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "").replace("The image depicts ", "").replace("The image presents ", "").replace("The image features ", "").replace("The image portrays ", "").replace("The image is ", "").strip()
+     if label.startswith("."):
+         label = label[1:].lstrip()
+     return label
+
+ def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
+     """
+     Prepares a list of labels for classifier-free guidance.
+
+     With probability prob_to_make_empty:
+     - the label in the first list is replaced with an empty string;
+     - the label in the second list gets a "zero:" prefix.
+
+     Otherwise both lists keep the original label.
+     """
+     labels_for_model = []
+     labels_for_logging = []
+
+     for label in original_labels:
+         if random.random() < prob_to_make_empty:
+             labels_for_model.append("")  # empty string goes to the model
+             labels_for_logging.append(f"zero: {label}")  # prefixed copy goes to the logs
+         else:
+             labels_for_model.append(label)  # keep the original label for the model
+             labels_for_logging.append(label)  # keep the original label for the logs
+
+     return labels_for_model, labels_for_logging
+
+ def encode_to_latents(images, texts):
+     transform = get_image_transform(min_size, max_size, step)
+
+     try:
+         # Process the images (all the same size within a batch)
+         transformed_tensors = []
+         pil_images = []
+         widths, heights = [], []
+
+         # Apply the transform to every image
+         for img in images:
+             try:
+                 t_img, pil_img, w, h = transform(img)
+                 transformed_tensors.append(t_img)
+                 pil_images.append(pil_img)
+                 widths.append(w)
+                 heights.append(h)
+             except Exception as e:
+                 print(f"Transform error: {e}")
+                 continue
+
+         if not transformed_tensors:
+             return None
+
+         # Build the batch
+         batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
+         # Note: stacking 3-D tensors always yields a 4-D [B, C, H, W] batch, so this
+         # branch never fires for a 2-D VAE; it is a leftover for video VAEs that
+         # expect a temporal axis, i.e. [B, C, 1, H, W].
+         if batch_tensor.ndim == 5:
+             batch_tensor = batch_tensor.unsqueeze(2)
+
+         # Encode the batch
+         with torch.no_grad():
+             posteriors = vae.encode(batch_tensor).latent_dist.mode()
+             latents = (posteriors - shift_factor) / scaling_factor
+
+         latents_np = latents.to(dtype).cpu().numpy()
+
+         # Process the captions
+         text_labels = [clean_label(text) for text in texts]
+
+         model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
+         #embeddings = encode_texts_batch(model_prompts, tokenizer, model)
+
+         return {
+             "vae": latents_np,
+             #"embeddings": embeddings,
+             "text": text_labels,
+             "width": widths,
+             "height": heights
+         }
+
+     except Exception as e:
+         print(f"Fatal error in encode_to_latents: {e}")
+         raise
+
+
+ # ---------------- 5️⃣ Walking the image/caption folder ----------------
+ def process_folder(folder_path, limit=None):
+     """
+     Recursively walks the given directory and all subdirectories,
+     collecting paths to images and their matching caption files.
+     """
+     image_paths = []
+     text_paths = []
+     width = []
+     height = []
+     transform = get_image_transform(min_size, max_size, step)
+
+     # os.walk for the recursive traversal
+     for root, dirs, files in os.walk(folder_path):
+         for filename in files:
+             # Only image files
+             if filename.lower().endswith((".jpg", ".jpeg", ".png")):
+                 image_path = os.path.join(root, filename)
+                 try:
+                     img = Image.open(image_path)
+                 except Exception as e:
+                     print(f"Failed to open {image_path}: {e}")
+                     os.remove(image_path)
+                     text_path = os.path.splitext(image_path)[0] + ".txt"
+                     if os.path.exists(text_path):
+                         os.remove(text_path)
+                     continue
+                 # Run the transform only to get the target dimensions
+                 w, h = transform(img, dry_run=True)
+                 # Path to the matching caption file
+                 text_path = os.path.splitext(image_path)[0] + ".txt"
+
+                 # Keep the pair only if the caption file exists
+                 if os.path.exists(text_path) and min(w, h) > 0:
+                     image_paths.append(image_path)
+                     text_paths.append(text_path)
+                     width.append(w)
+                     height.append(h)
+
+                 # Respect the optional limit
+                 if limit and limit > 0 and len(image_paths) >= limit:
+                     print(f"Reached the limit of {limit} images")
+                     return image_paths, text_paths, width, height
+
+     print(f"Found {len(image_paths)} images with captions")
+     return image_paths, text_paths, width, height
+
+ def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
+     total_files = len(image_paths)
+     start_time = time.time()
+     chunks = range(0, total_files, chunk_size)
+
+     for chunk_idx, start in enumerate(chunks, 1):
+         end = min(start + chunk_size, total_files)
+         chunk_image_paths = image_paths[start:end]
+         chunk_text_paths = text_paths[start:end]
+         chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
+         chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)
+
+         # Read the captions
+         chunk_texts = []
+         for text_path in chunk_text_paths:
+             try:
+                 with open(text_path, 'r', encoding='utf-8') as f:
+                     text = f.read().strip()
+                 chunk_texts.append(text)
+             except Exception as e:
+                 print(f"Failed to read {text_path}: {e}")
+                 chunk_texts.append("")
+
+         # Group the images by size
+         size_groups = {}
+         for i in range(len(chunk_image_paths)):
+             size_key = (chunk_widths[i], chunk_heights[i])
+             if size_key not in size_groups:
+                 size_groups[size_key] = {"image_paths": [], "texts": []}
+             size_groups[size_key]["image_paths"].append(chunk_image_paths[i])
+             size_groups[size_key]["texts"].append(chunk_texts[i])
+
+         # Process each size group separately
+         processed_in_chunk = 0  # running counter for the ETA estimate
+         for size_key, group_data in size_groups.items():
+             print(f"Processing group of size {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} images")
+
+             group_dataset = Dataset.from_dict({
+                 "image_path": group_data["image_paths"],
+                 "text": group_data["texts"]
+             })
+
+             # The requested batch_size is safe here: all images in the group share one size
+             processed_group = group_dataset.map(
+                 lambda examples: encode_to_latents(
+                     [Image.open(path) for path in examples["image_path"]],
+                     examples["text"]
+                 ),
+                 batched=True,
+                 batch_size=batch_size,
+                 #remove_columns=["image_path"],
+                 desc=f"Processing size group {size_key[0]}x{size_key[1]}"
+             )
+
+             # Save the group's results
+             group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
+             processed_group.save_to_disk(group_save_path)
+             clear_cuda_memory()
+             elapsed = time.time() - start_time
+             # `start` is the chunk offset, so this replaces the fragile
+             # list(...).index(group_data) lookup in the original
+             processed_in_chunk += len(group_data["image_paths"])
+             processed = start + processed_in_chunk
+             if processed > 0:
+                 remaining = (elapsed / processed) * (total_files - processed)
+                 elapsed_str = str(timedelta(seconds=int(elapsed)))
+                 remaining_str = str(timedelta(seconds=int(remaining)))
+                 print(f"ETA: elapsed {elapsed_str}, remaining {remaining_str}, progress {processed}/{total_files} ({processed/total_files:.1%})")
+
+ # ---------------- 7️⃣ Merging the chunks ----------------
+ def combine_chunks(temp_path, final_path):
+     """Merge the processed chunks into the final dataset"""
+     chunks = sorted([
+         os.path.join(temp_path, d)
+         for d in os.listdir(temp_path)
+         if d.startswith("chunk_")
+     ])
+
+     datasets = [load_from_disk(chunk) for chunk in chunks]
+     combined = concatenate_datasets(datasets)
+     combined.save_to_disk(final_path)
+
+     print(f"✅ Dataset saved to: {final_path}")
+
+
+
+ # Temporary folder for the chunks
+ temp_path = f"{save_path}_temp"
+ os.makedirs(temp_path, exist_ok=True)
+
+ # Collect the file lists
+ image_paths, text_paths, width, height = process_folder(folder_path, limit)
+ print(f"Found {len(image_paths)} images in total")
+
+ # Chunked processing
+ process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size)
+
+ # Remove the source folder
+ try:
+     shutil.rmtree(folder_path)
+     print(f"✅ Folder {folder_path} removed")
+ except Exception as e:
+     print(f"⚠️ Failed to remove folder: {e}")
+
+ # Merge the chunks into the final dataset
+ combine_chunks(temp_path, save_path)
+
+ # Remove the temporary folder
+ try:
+     shutil.rmtree(temp_path)
+     print(f"✅ Temporary folder {temp_path} removed")
+ except Exception as e:
+     print(f"⚠️ Failed to remove temporary folder: {e}")
.ipynb_checkpoints/sdxs_create_flux-checkpoint.ipynb ADDED
@@ -0,0 +1,600 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "6bf71a1a-1bf0-42c7-8709-6686e8d2f46c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "test unet\n",
+ "Parameter count: 798780960\n",
+ "Output shape: torch.Size([1, 32, 60, 48])\n",
+ "UNet2DConditionModel(\n",
+ " (conv_in): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_proj): Timesteps()\n",
+ " (time_embedding): TimestepEmbedding(\n",
+ " (linear_1): Linear(in_features=128, out_features=512, bias=True)\n",
+ " (act): SiLU()\n",
+ " (linear_2): Linear(in_features=512, out_features=512, bias=True)\n",
+ " )\n",
+ " (down_blocks): ModuleList(\n",
+ " (0): DownBlock2D(\n",
+ " (resnets): ModuleList(\n",
+ " (0-1): 2 x ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)\n",
+ " (norm2): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " )\n",
+ " )\n",
+ " (downsamplers): ModuleList(\n",
+ " (0): Downsample2D(\n",
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (1): DownBlock2D(\n",
+ " (resnets): ModuleList(\n",
+ " (0): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (1): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " )\n",
+ " )\n",
+ " (downsamplers): ModuleList(\n",
+ " (0): Downsample2D(\n",
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (2): CrossAttnDownBlock2D(\n",
+ " (attentions): ModuleList(\n",
+ " (0-1): 2 x Transformer2DModel(\n",
+ " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ " (proj_in): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (transformer_blocks): ModuleList(\n",
+ " (0-1): 2 x BasicTransformerBlock(\n",
+ " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn1): Attention(\n",
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
+ " (to_k): Linear(in_features=512, out_features=512, bias=False)\n",
+ " (to_v): Linear(in_features=512, out_features=512, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn2): Attention(\n",
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
+ " (to_k): Linear(in_features=1024, out_features=512, bias=False)\n",
+ " (to_v): Linear(in_features=1024, out_features=512, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (ff): FeedForward(\n",
+ " (net): ModuleList(\n",
+ " (0): GEGLU(\n",
+ " (proj): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (proj_out): Linear(in_features=512, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (resnets): ModuleList(\n",
+ " (0): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (1): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " )\n",
+ " )\n",
+ " (downsamplers): ModuleList(\n",
+ " (0): Downsample2D(\n",
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (3): CrossAttnDownBlock2D(\n",
+ " (attentions): ModuleList(\n",
+ " (0-1): 2 x Transformer2DModel(\n",
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (transformer_blocks): ModuleList(\n",
+ " (0-3): 4 x BasicTransformerBlock(\n",
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn1): Attention(\n",
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn2): Attention(\n",
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (ff): FeedForward(\n",
+ " (net): ModuleList(\n",
+ " (0): GEGLU(\n",
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (resnets): ModuleList(\n",
+ " (0): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (1): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (up_blocks): ModuleList(\n",
+ " (0): CrossAttnUpBlock2D(\n",
+ " (attentions): ModuleList(\n",
+ " (0-2): 3 x Transformer2DModel(\n",
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (transformer_blocks): ModuleList(\n",
+ " (0-3): 4 x BasicTransformerBlock(\n",
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn1): Attention(\n",
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn2): Attention(\n",
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (ff): FeedForward(\n",
+ " (net): ModuleList(\n",
+ " (0): GEGLU(\n",
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (resnets): ModuleList(\n",
+ " (0-1): 2 x ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 2048, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(2048, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (2): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 1536, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(1536, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(1536, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " (upsamplers): ModuleList(\n",
+ " (0): Upsample2D(\n",
+ " (conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (1): CrossAttnUpBlock2D(\n",
+ " (attentions): ModuleList(\n",
+ " (0-2): 3 x Transformer2DModel(\n",
+ " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
+ " (proj_in): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (transformer_blocks): ModuleList(\n",
+ " (0-1): 2 x BasicTransformerBlock(\n",
+ " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn1): Attention(\n",
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
+ " (to_k): Linear(in_features=512, out_features=512, bias=False)\n",
+ " (to_v): Linear(in_features=512, out_features=512, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn2): Attention(\n",
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
+ " (to_k): Linear(in_features=1024, out_features=512, bias=False)\n",
+ " (to_v): Linear(in_features=1024, out_features=512, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
+ " (ff): FeedForward(\n",
+ " (net): ModuleList(\n",
+ " (0): GEGLU(\n",
+ " (proj): Linear(in_features=512, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (proj_out): Linear(in_features=512, out_features=512, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (resnets): ModuleList(\n",
+ " (0): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 1536, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(1536, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(1536, 512, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (1): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (2): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 768, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " (upsamplers): ModuleList(\n",
+ " (0): Upsample2D(\n",
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (2): UpBlock2D(\n",
+ " (resnets): ModuleList(\n",
+ " (0): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 768, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (1): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (2): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 384, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " (upsamplers): ModuleList(\n",
+ " (0): Upsample2D(\n",
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (3): UpBlock2D(\n",
+ " (resnets): ModuleList(\n",
+ " (0): ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 384, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)\n",
+ " (norm2): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (1-2): 2 x ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)\n",
+ " (norm2): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (mid_block): UNetMidBlock2DCrossAttn(\n",
+ " (attentions): ModuleList(\n",
+ " (0): Transformer2DModel(\n",
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (transformer_blocks): ModuleList(\n",
+ " (0-3): 4 x BasicTransformerBlock(\n",
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn1): Attention(\n",
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn2): Attention(\n",
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (to_out): ModuleList(\n",
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (ff): FeedForward(\n",
+ " (net): ModuleList(\n",
+ " (0): GEGLU(\n",
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
+ " )\n",
+ " (1): Dropout(p=0.0, inplace=False)\n",
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (resnets): ModuleList(\n",
+ " (0-1): 2 x ResnetBlock2D(\n",
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
+ " (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (nonlinearity): SiLU()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (conv_norm_out): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
+ " (conv_act): SiLU()\n",
+ " (conv_out): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ ")\n",
+ "Output shape: torch.Size([1, 32, 60, 48])\n"
+ ]
+ }
+ ],
+ "source": [
+ "config_sdxs = {\n",
+ "    # === Core sizes and channels ===\n",
+ "    \"in_channels\": 32, # Number of input channels (must match the VAE latent channels)\n",
+ "    \"out_channels\": 32, # Number of output channels (symmetric with in_channels)\n",
+ "    \"center_input_sample\": False, # No input centering (standard for diffusion models)\n",
+ "    \"flip_sin_to_cos\": True, # Swap sin/cos in the time embeddings (for stability)\n",
+ "    \"freq_shift\": 0, # Frequency shift (0 is the standard value for frequency embeddings)\n",
+ "\n",
+ "    # === Block architecture ===\n",
+ "    \"down_block_types\": [ # Encoder block types (processing hierarchy):\n",
+ "        \"DownBlock2D\",\n",
+ "        \"DownBlock2D\",\n",
+ "        \"CrossAttnDownBlock2D\",\n",
+ "        \"CrossAttnDownBlock2D\",\n",
+ "    ],\n",
+ "    \"mid_block_type\": \"UNetMidBlock2DCrossAttn\", # Middle block with cross-attention (the network bottleneck)\n",
+ "    \"up_block_types\": [ # Decoder block types (image reconstruction):\n",
+ "        \"CrossAttnUpBlock2D\",\n",
+ "        \"CrossAttnUpBlock2D\",\n",
+ "        \"UpBlock2D\",\n",
+ "        \"UpBlock2D\",\n",
+ "    ],\n",
+ "    \"only_cross_attention\": False, # Use both cross-attention and self-attention\n",
+ "\n",
+ "    # === Channel configuration ===\n",
+ "    \"block_out_channels\": [128, 256, 512, 1024],\n",
+ "    \"layers_per_block\": 2, # Layers per block\n",
+ "    \"downsample_padding\": 1, # Padding when downsampling\n",
+ "    \"mid_block_scale_factor\": 1.0, # Signal gain in the middle block\n",
+ "\n",
+ "    # === Normalization ===\n",
+ "    \"norm_num_groups\": 32, # Number of GroupNorm groups (good for stability)\n",
+ "    \"norm_eps\": 1e-05, # Normalization epsilon (standard value)\n",
+ "\n",
+ "    # === Cross-attention ===\n",
+ "    \"cross_attention_dim\": 1024, # Dimensionality of the text embeddings\n",
+ "\n",
+ "    \"transformer_layers_per_block\": [1, 1, 2, 4], # Number of transformer layers (grows with depth)\n",
+ "    \"attention_head_dim\": [2, 4, 8, 16], # Attention head dimension\n",
+ "    \"dual_cross_attention\": False, # No dual cross-attention (keeps the architecture simple)\n",
+ "    \"use_linear_projection\": True, # Changed to True for better memory layout\n",
+ "\n",
+ "    # === ResNet blocks ===\n",
+ "    \"resnet_time_scale_shift\": \"default\", # How time embeddings are injected\n",
+ "    \"resnet_skip_time_act\": False, # No activation in the skip connections\n",
+ "    \"resnet_out_scale_factor\": 1.0, # ResNet output scaling factor\n",
+ "\n",
+ "    # === Time embeddings ===\n",
+ "    \"time_embedding_type\": \"positional\", # Time embedding type (the standard approach)\n",
+ "\n",
+ "    # === Convolutions ===\n",
+ "    \"conv_in_kernel\": 3, # Input conv kernel (balances receptive field and parameter count)\n",
+ "    \"conv_out_kernel\": 3, # Output conv kernel (symmetric with the input)\n",
+ "}\n",
+ "\n",
+ "if 1:\n",
+ "    checkpoint_path = \"/workspace/sdxs/sdxs_flux\"#\"sdxs\"\n",
+ "    import torch\n",
+ "    from diffusers import UNet2DConditionModel\n",
+ "    print(\"test unet\")\n",
+ "    new_unet = UNet2DConditionModel(**config_sdxs).to(\"cuda\", dtype=torch.float16)\n",
+ "\n",
+ "    assert all(ch % 32 == 0 for ch in new_unet.config[\"block_out_channels\"]), \"Channels must be divisible by 32\"\n",
+ "    num_params = sum(p.numel() for p in new_unet.parameters())\n",
+ "    print(f\"Parameter count: {num_params}\")\n",
+ "\n",
+ "    # Generate a test latent (640x512 in latent space)\n",
+ "    test_latent = torch.randn(1, 32, 60, 48).to(\"cuda\", dtype=torch.float16) # 60x48 ≈ 512px\n",
+ "    timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n",
+ "    encoder_hidden_states = torch.randn(1, 77, 1024).to(\"cuda\", dtype=torch.float16)\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        output = new_unet(\n",
+ "            test_latent,\n",
+ "            timesteps,\n",
+ "            encoder_hidden_states\n",
+ "        ).sample\n",
+ "\n",
+ "    print(f\"Output shape: {output.shape}\")\n",
+ "    new_unet.save_pretrained(checkpoint_path)\n",
+ "    print(new_unet)\n",
+ "    del new_unet\n",
+ "    torch.cuda.empty_cache()\n",
+ "    print(f\"Output shape: {output.shape}\")\n",
+ "    # Parameter counts: 1228819856 2042991632 1889899536"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a81bca2d-d96d-4794-90f4-b54ea8682b52",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
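
Loading the checkpoint that the cell above wrote with save_pretrained is symmetric. A minimal sketch (not part of the commit; path and shapes taken from the cell above):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("/workspace/sdxs/sdxs_flux").to("cuda", dtype=torch.float16)
print(sum(p.numel() for p in unet.parameters()))  # the cell above reports 798780960

latent = torch.randn(1, 32, 60, 48, device="cuda", dtype=torch.float16)
text_emb = torch.randn(1, 77, 1024, device="cuda", dtype=torch.float16)
with torch.no_grad():
    out = unet(latent, torch.tensor([1], device="cuda"), text_emb).sample
print(out.shape)  # torch.Size([1, 32, 60, 48]), matching the recorded output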
.ipynb_checkpoints/train_flux-checkpoint.py ADDED
@@ -0,0 +1,802 @@
1
+ #from comet_ml import Experiment
2
+ import os
3
+ import math
4
+ import torch
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from torch.utils.data import DataLoader, Sampler
8
+ from torch.utils.data.distributed import DistributedSampler
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ from collections import defaultdict
11
+ from diffusers import UNet2DConditionModel, AutoencoderKL,AutoencoderKLFlux2
12
+ from accelerate import Accelerator
13
+ from datasets import load_from_disk
14
+ from tqdm import tqdm
15
+ from PIL import Image, ImageOps
16
+ import wandb
17
+ import random
18
+ import gc
19
+ from accelerate.state import DistributedType
20
+ from torch.distributed import broadcast_object_list
21
+ from torch.utils.checkpoint import checkpoint
22
+ from diffusers.models.attention_processor import AttnProcessor2_0
23
+ from datetime import datetime
24
+ import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
+ from collections import deque
27
+ from transformers import AutoTokenizer, AutoModel
28
+
29
+ # --------------------------- Параметры ---------------------------
30
+ ds_path = "/workspace/sdxs/datasets/mjnj"
31
+ project = "sdxs_flux"
32
+ batch_size = 32
33
+ base_learning_rate = 4e-5 #2.7e-5
34
+ min_learning_rate = 9e-6 #2.7e-5
35
+ num_epochs = 10
36
+ sample_interval_share = 10
37
+ cfg_dropout = 0.9
38
+ max_length = 192
39
+ use_wandb = True
40
+ use_comet_ml = False
41
+ save_model = True
42
+ use_decay = True
43
+ fbp = False
44
+ optimizer_type = "adam8bit"
45
+ torch_compile = False
46
+ unet_gradient = True
47
+ fixed_seed = False
48
+ shuffle = True
49
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
50
+ comet_ml_workspace = "recoilme"
51
+ torch.backends.cuda.matmul.allow_tf32 = True
52
+ torch.backends.cudnn.allow_tf32 = True
53
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
54
+ dtype = torch.float32
55
+ save_barrier = 1.01
56
+ warmup_percent = 0.01
57
+ percentile_clipping = 95 #96 #97
58
+ betta2 = 0.995
59
+ eps = 1e-7
60
+ clip_grad_norm = 1.0
61
+ limit = 0
62
+ checkpoints_folder = ""
63
+ mixed_precision = "no"
64
+ gradient_accumulation_steps = 1
65
+
66
+ accelerator = Accelerator(
67
+ mixed_precision=mixed_precision,
68
+ gradient_accumulation_steps=gradient_accumulation_steps
69
+ )
70
+ device = accelerator.device
71
+
72
+ # Параметры для диффузии
73
+ n_diffusion_steps = 40
74
+ samples_to_generate = 12
75
+ guidance_scale = 4
76
+
77
+ # Папки для сохранения результатов
78
+ generated_folder = "samples"
79
+ os.makedirs(generated_folder, exist_ok=True)
80
+
81
+ # Настройка seed
82
+ current_date = datetime.now()
83
+ seed = int(current_date.strftime("%Y%m%d"))
84
+ if fixed_seed:
85
+ torch.manual_seed(seed)
86
+ np.random.seed(seed)
87
+ random.seed(seed)
88
+ if torch.cuda.is_available():
89
+ torch.cuda.manual_seed_all(seed)
90
+
91
+ # --------------------------- Параметры LoRA ---------------------------
92
+ lora_name = ""
93
+ lora_rank = 32
94
+ lora_alpha = 64
95
+
96
+ print("init")
97
+
98
+ loss_ratios = {
99
+ "mse": 1.5,
100
+ "mae": 0.5,
101
+ }
102
+ median_coeff_steps = 256
103
+
104
+ # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ
105
+ class MedianLossNormalizer:
106
+ def __init__(self, desired_ratios: dict, window_steps: int):
107
+ # нормируем доли на случай, если сумма != 1
108
+ #s = sum(desired_ratios.values())
109
+ #self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
110
+ #self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
111
+ #self.window = window_steps
112
+ self.ratios = {k: float(v) for k, v in desired_ratios.items()}
113
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
114
+ self.window = window_steps
115
+
116
+ def update_and_total(self, losses: dict):
117
+ """
118
+ losses: dict ключ->тензор (значения лоссов)
119
+ Поведение:
120
+ - буферим ABS(l) только для активных (ratio>0) лоссов
121
+ - coeff = ratio / median(abs(loss))
122
+ - total = sum(coeff * loss) по активным лоссам
123
+ CHANGED: буферим abs() — чтобы медиана была положительной и не ломала деление.
124
+ """
125
+ # буферим только активные лоссы
126
+ for k, v in losses.items():
127
+ if k in self.buffers and self.ratios.get(k, 0) > 0:
128
+ val = v.detach().abs().mean().cpu().item() # .item() лучше float() для тензоров
129
+ self.buffers[k].append(val)
130
+ #self.buffers[k].append(float(v.detach().abs().cpu()))
131
+
132
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
133
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
134
+
135
+ # суммируем только по активным (ratio>0)
136
+ total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
137
+ return total, coeffs, meds
138
+
139
+ # создаём normalizer после определения loss_ratios
140
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
141
+
142
+ # --------------------------- Ин��циализация WandB ---------------------------
143
+ if accelerator.is_main_process:
144
+ if use_wandb:
145
+ wandb.init(project=project+lora_name, config={
146
+ "batch_size": batch_size,
147
+ "base_learning_rate": base_learning_rate,
148
+ "num_epochs": num_epochs,
149
+ "optimizer_type": optimizer_type,
150
+ })
151
+ if use_comet_ml:
152
+ from comet_ml import Experiment
153
+ comet_experiment = Experiment(
154
+ api_key=comet_ml_api_key,
155
+ project_name=project,
156
+ workspace=comet_ml_workspace
157
+ )
158
+ hyper_params = {
159
+ "batch_size": batch_size,
160
+ "base_learning_rate": base_learning_rate,
161
+ "num_epochs": num_epochs,
162
+ }
163
+ comet_experiment.log_parameters(hyper_params)
164
+
165
+ # Включение Flash Attention 2/SDPA
166
+ torch.backends.cuda.enable_flash_sdp(True)
167
+
+ # --------------------------- Model loading ---------------------------
+ #vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
+ vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae", torch_dtype=dtype).to(device).eval()
+ tokenizer = AutoTokenizer.from_pretrained("tokenizer")
+ text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
+
+ # --- [UPDATED] Text encoding function (with mask and pooling) ---
+ def encode_texts(texts, max_length=max_length):
+     # If the texts are empty (for unconditional), create placeholders
+     if texts is None:
+         # For None we would return zeros (that logic lives in get_negative_embedding);
+         # here we normally expect a list of strings.
+         pass
+
+     with torch.no_grad():
+         if isinstance(texts, str):
+             texts = [texts]
+
+         for i, prompt_item in enumerate(texts):
+             messages = [
+                 {"role": "user", "content": prompt_item},
+             ]
+             prompt_item = tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True,
+                 #enable_thinking=True,
+             )
+             #print(prompt_item+"\n")
+             texts[i] = prompt_item
+
+         toks = tokenizer(
+             texts,
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=max_length
+         ).to(device)
+
+         outs = text_model(**toks, output_hidden_states=True, return_dict=True)
+
+         # Use last_hidden_state or hidden_states[-1]? (for Qwen, last_hidden_state is the usual advice;
+         # in practice hidden_states[-2] worked better here)
+         hidden = outs.hidden_states[-2]
+
+         # 2. Attention mask
+         attention_mask = toks["attention_mask"]
+
+         # 3. Pooled embedding (last token)
+         sequence_lengths = attention_mask.sum(dim=1) - 1
+         batch_size = hidden.shape[0]
+         pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
+
+         #return hidden, attention_mask
+         # --- NEW LOGIC: CONCATENATION FOR CROSS-ATTENTION ---
+         # 1. Expand the pooled vector into a sequence [B, 1, emb]
+         pooled_expanded = pooled.unsqueeze(1)
+
+         # 2. Concatenate the token sequence and the pooled vector
+         # !!! CHANGE HERE !!!: the pooled vector goes FIRST
+         # Result: [B, 1 + L, emb]. The pooled vector becomes a token at the FRONT.
+         new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
+
+         # 3. Update the attention mask for the new token
+         # Attention mask: [B, 1 + L]. Prepend a 1.
+         # torch.ones((batch_size, 1), device=device) builds a [B, 1] mask of ones.
+         new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
+
+         return new_encoder_hidden_states, new_attention_mask
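
A quick sanity check of the shapes this returns (a minimal sketch; B, L, D are illustrative sizes, not values from the run, and the pooled vector stands in for last-token pooling with a full mask):

import torch

B, L, D = 2, 77, 1024            # illustrative batch, sequence length, hidden size
hidden = torch.randn(B, L, D)    # stand-in for outs.hidden_states[-2]
mask = torch.ones(B, L)
pooled = hidden[:, -1]           # last-token pooling, shape [B, D]

seq = torch.cat([pooled.unsqueeze(1), hidden], dim=1)    # [B, 1+L, D]
new_mask = torch.cat([torch.ones(B, 1), mask], dim=1)    # [B, 1+L]
assert seq.shape == (B, 1 + L, D) and new_mask.shape == (B, 1 + L)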
+
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
+ if shift_factor is None: shift_factor = 0.0
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
+ if scaling_factor is None: scaling_factor = 1.0
+
+ from diffusers import FlowMatchEulerDiscreteScheduler
+ num_train_timesteps = 1000
+ scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps)
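
For orientation: the training loop further down uses the rectified-flow parameterization, where the network predicts the velocity v = noise - x0 along the straight path x_t = (1 - t)·x0 + t·noise. A minimal sketch of that identity (illustrative tensors only):

import torch

x0 = torch.randn(1, 32, 8, 8)        # clean latents (illustrative shape)
noise = torch.randn_like(x0)
t = torch.full((1, 1, 1, 1), 0.5)    # any t in (0, 1]

x_t = (1.0 - t) * x0 + t * noise     # straight interpolation path
target = noise - x0                  # velocity the UNet learns to predict
# d(x_t)/dt == noise - x0, so the target is exact for every t
assert torch.allclose((x_t - x0) / t, target, atol=1e-6)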
+
+ class DistributedResolutionBatchSampler(Sampler):
+     def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
+         self.dataset = dataset
+         self.batch_size = max(1, batch_size // num_replicas)
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.shuffle = shuffle
+         self.drop_last = drop_last
+         self.epoch = 0
+
+         try:
+             widths = np.array(dataset["width"])
+             heights = np.array(dataset["height"])
+         except KeyError:
+             widths = np.zeros(len(dataset))
+             heights = np.zeros(len(dataset))
+
+         # group sample indices by (width, height) so every batch is single-resolution
+         self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
+         self.size_groups = {}
+         for w, h in self.size_keys:
+             mask = (widths == w) & (heights == h)
+             self.size_groups[(w, h)] = np.where(mask)[0]
+
+         self.group_num_batches = {}
+         total_batches = 0
+         for size, indices in self.size_groups.items():
+             num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
+             self.group_num_batches[size] = num_full_batches
+             total_batches += num_full_batches
+
+         self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
+
+     def __iter__(self):
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         all_batches = []
+         rng = np.random.RandomState(self.epoch)
+
+         for size, indices in self.size_groups.items():
+             indices = indices.copy()
+             if self.shuffle:
+                 rng.shuffle(indices)
+             num_full_batches = self.group_num_batches[size]
+             if num_full_batches == 0:
+                 continue
+             valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
+             batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
+             start_idx = self.rank * self.batch_size
+             end_idx = start_idx + self.batch_size
+             gpu_batches = batches[:, start_idx:end_idx]
+             all_batches.extend(gpu_batches)
+
+         if self.shuffle:
+             rng.shuffle(all_batches)
+         accelerator.wait_for_everyone()
+         return iter(all_batches)
+
+     def __len__(self):
+         return self.num_batches
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
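
The sampler above implements resolution bucketing: every batch is drawn from one (width, height) group and split evenly across ranks. A minimal single-process sketch (assuming `from datasets import Dataset` and the `accelerator` from this script; the toy data is hypothetical):

toy_dataset = Dataset.from_dict({
    "width":  [256, 256, 256, 256, 320, 320],
    "height": [256, 256, 256, 256, 256, 256],
})
sampler = DistributedResolutionBatchSampler(
    dataset=toy_dataset, batch_size=2, num_replicas=1, rank=0, shuffle=False
)
for idxs in sampler:
    print(idxs)   # [0 1], [2 3], [4 5] — 256x256 and 320x256 never share a batch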
+ # --- [UPDATED] Fixed validation samples, grouped by resolution ---
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
+     size_groups = defaultdict(list)
+     try:
+         widths = dataset["width"]
+         heights = dataset["height"]
+     except KeyError:
+         widths = [0] * len(dataset)
+         heights = [0] * len(dataset)
+     for i, (w, h) in enumerate(zip(widths, heights)):
+         size = (w, h)
+         size_groups[size].append(i)
+
+     fixed_samples = {}
+     for size, indices in size_groups.items():
+         n_samples = min(samples_per_group, len(indices))
+         if len(size_groups) == 1:
+             n_samples = samples_to_generate
+         if n_samples == 0:
+             continue
+         sample_indices = random.sample(indices, n_samples)
+         samples_data = [dataset[idx] for idx in sample_indices]
+
+         latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
+         texts = [item["text"] for item in samples_data]
+
+         # Encode the texts on the fly to get masks and pooling
+         embeddings, masks = encode_texts(texts)
+
+         fixed_samples[size] = (latents, embeddings, masks, texts)
+
+     print(f"Created {len(fixed_samples)} groups of fixed samples by resolution")
+     return fixed_samples
+
+ if limit > 0:
+     dataset = load_from_disk(ds_path).select(range(limit))
+ else:
+     dataset = load_from_disk(ds_path)
+
+ dataset = dataset.filter(
+     lambda x: [not (path.startswith("/workspace/ds/animesfw") or path.startswith("/workspace/ds/d4/animesfw")) for path in x["image_path"]],
+     batched=True,
+     batch_size=10000,  # process 10k rows at a time
+     num_proc=8
+ )
+ print(f"Examples left after filtering: {len(dataset)}")
+
+ # --- [UPDATED] Collate Function ---
+ def collate_fn_simple(batch):
+     # 1. Latents (VAE)
+     latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)
+
+     # 2. Take the raw text from the dataset
+     raw_texts = [item["text"] for item in batch]
+     texts = [
+         "" if t.lower().startswith("zero")
+         else "" if random.random() < cfg_dropout
+         else t[1:].lstrip() if t.startswith(".")
+         else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ", "").strip()
+         for t in raw_texts
+     ]
+
+     # 3. Encode on the fly
+     # Returns: hidden (B, 1+L, D), mask (B, 1+L) — the pooled token is prepended by encode_texts
+     embeddings, attention_mask = encode_texts(texts)
+
+     # The tokenizer's attention_mask already has the right layout, but cast to long just in case
+     attention_mask = attention_mask.to(dtype=torch.int64)
+
+     return latents, embeddings, attention_mask
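
For reference, here is how the cleanup rules above behave on a few illustrative captions (hypothetical strings; assume cfg_dropout is effectively 0 so the random branch never fires):

raw_texts = ["zero: a cat", ". a dog", "The image shows a boat at sunset"]
# "zero: a cat"                      -> ""                  (explicit unconditional marker)
# ". a dog"                          -> "a dog"             (leading dot stripped)
# "The image shows a boat at sunset" -> "a boat at sunset"  (boilerplate prefix removed)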
+
+ batch_sampler = DistributedResolutionBatchSampler(
+     dataset=dataset,
+     batch_size=batch_size,
+     num_replicas=accelerator.num_processes,
+     rank=accelerator.process_index,
+     shuffle=shuffle
+ )
+
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
+ if accelerator.is_main_process:
+     print("Total batches", len(dataloader))
+ dataloader = accelerator.prepare(dataloader)
+
+ start_epoch = 0
+ global_step = 0
+ total_training_steps = (len(dataloader) * num_epochs)
+ world_size = accelerator.state.num_processes
+
+ # Load the UNet
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
+ if os.path.isdir(latest_checkpoint):
+     print("Loading UNet from checkpoint:", latest_checkpoint)
+     unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
+     if unet_gradient:
+         unet.enable_gradient_checkpointing()
+     unet.set_use_memory_efficient_attention_xformers(False)
+     try:
+         unet.set_attn_processor(AttnProcessor2_0())
+     except Exception as e:
+         print(f"Failed to enable SDPA: {e}")
+         unet.set_use_memory_efficient_attention_xformers(True)
+ else:
+     raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")
+
+ if lora_name:
+     # ... (LoRA code unchanged and omitted for brevity; restore the original block here if LoRA is used) ...
+     pass
+
+ # Optimizer
+ if lora_name:
+     trainable_params = [p for p in unet.parameters() if p.requires_grad]
+ else:
+     if fbp:
+         trainable_params = list(unet.parameters())
+
+ def create_optimizer(name, params):
+     if name == "adam8bit":
+         return bnb.optim.AdamW8bit(
+             params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
+             percentile_clipping=percentile_clipping
+         )
+     elif name == "adam":
+         return torch.optim.AdamW(
+             params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
+         )
+     elif name == "muon":
+         from muon import MuonWithAuxAdam
+         trainable = [p for p in params if p.requires_grad]
+         hidden_weights = [p for p in trainable if p.ndim >= 2]
+         hidden_gains_biases = [p for p in trainable if p.ndim < 2]
+
+         param_groups = [
+             dict(params=hidden_weights, use_muon=True,
+                  lr=1e-3, weight_decay=1e-4),
+             dict(params=hidden_gains_biases, use_muon=False,
+                  lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
+         ]
+         optimizer = MuonWithAuxAdam(param_groups)
+         from snooc import SnooC
+         return SnooC(optimizer)
+     else:
+         raise ValueError(f"Unknown optimizer: {name}")
+
+ if fbp:
+     # fused backward pass: one small optimizer per parameter, stepped from a
+     # post-accumulate-grad hook so full gradients never have to be held at once
+     optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
+     def optimizer_hook(param):
+         optimizer_dict[param].step()
+         optimizer_dict[param].zero_grad(set_to_none=True)
+     for param in trainable_params:
+         param.register_post_accumulate_grad_hook(optimizer_hook)
+     unet, optimizer = accelerator.prepare(unet, optimizer_dict)
+ else:
+     optimizer = create_optimizer(optimizer_type, unet.parameters())
+     def lr_schedule(step):
+         x = step / (total_training_steps * world_size)
+         warmup = warmup_percent
+         if not use_decay:
+             return base_learning_rate
+         if x < warmup:
+             return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
+         decay_ratio = (x - warmup) / (1 - warmup)
+         return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
+             (1 + math.cos(math.pi * decay_ratio))
+     lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
+     unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
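
The schedule is linear warmup into a cosine decay. A minimal sketch of the values it produces, with hypothetical numbers (base lr 1e-4, min lr 1e-6, 10% warmup) rather than this run's settings:

import math

base_lr, min_lr, warmup = 1e-4, 1e-6, 0.1

def lr_at(x):  # x = fraction of total training completed
    if x < warmup:
        return min_lr + (base_lr - min_lr) * (x / warmup)
    decay_ratio = (x - warmup) / (1 - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))

for x in (0.0, 0.05, 0.1, 0.55, 1.0):
    print(f"{x:.2f} -> {lr_at(x):.2e}")
# 0.00 -> 1e-6, 0.05 -> ~5e-5, 0.10 -> 1e-4 (peak), 0.55 -> ~5e-5, 1.00 -> 1e-6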
+
+ if torch_compile:
+     print("compiling")
+     unet = torch.compile(unet)
+     print("compiling - ok")
+
+ # Fixed validation samples
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
+
+ # --- [UPDATED] Negative embedding helper (returns 2 elements: embedding and mask) ---
+ def get_negative_embedding(neg_prompt="", batch_size=1):
+     if not neg_prompt:
+         hidden_dim = 2048
+         seq_len = max_length
+         empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
+         empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
+         return empty_emb, empty_mask
+
+     uncond_emb, uncond_mask = encode_texts([neg_prompt])
+     uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
+     uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
+
+     return uncond_emb, uncond_mask
+
+ # Negative conditioning for validation (a "low quality" prompt here, not the empty one)
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
+
+ # --- Sample generation function ---
+ @torch.compiler.disable()
+ @torch.no_grad()
+ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
+     uncond_emb, uncond_mask = uncond_data
+
+     original_model = None
+     try:
+         if not torch_compile:
+             original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
+         else:
+             original_model = unet.eval()
+
+         vae.to(device=device).eval()
+
+         all_generated_images = []
+         all_captions = []
+
+         # Unpack the 4-tuple per resolution (the mask was added alongside latents/embeddings/texts)
+         for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
+             width, height = size
+             sample_latents = sample_latents.to(dtype=dtype, device=device)
+             sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
+             sample_mask = sample_mask.to(device=device)
+
+             latents = torch.randn(
+                 sample_latents.shape,
+                 device=device,
+                 dtype=sample_latents.dtype,
+                 generator=torch.Generator(device=device).manual_seed(seed)
+             )
+
+             scheduler.set_timesteps(n_diffusion_steps, device=device)
+
+             for t in scheduler.timesteps:
+                 if guidance_scale != 1:
+                     latent_model_input = torch.cat([latents, latents], dim=0)
+
+                     # Build the CFG batches (negative + positive)
+                     # 1. Embeddings
+                     curr_batch_size = sample_text_embeddings.shape[0]
+                     seq_len = sample_text_embeddings.shape[1]
+                     hidden_dim = sample_text_embeddings.shape[2]
+
+                     neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
+                     text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)
+
+                     # 2. Masks
+                     neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
+                     attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
+
+                 else:
+                     latent_model_input = latents
+                     text_embeddings_batch = sample_text_embeddings
+                     attention_mask_batch = sample_mask
+
+                 # Predict, passing all the conditions
+                 model_out = original_model(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=text_embeddings_batch,
+                     encoder_attention_mask=attention_mask_batch,
+                 )
+                 flow = getattr(model_out, "sample", model_out)
+
+                 if guidance_scale != 1:
+                     flow_uncond, flow_cond = flow.chunk(2)
+                     flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
+
+                 latents = scheduler.step(flow, t, latents).prev_sample
+
+             current_latents = latents
+
+             latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
+             decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
+             decoded_fp32 = decoded.to(torch.float32)
+
+             for img_idx, img_tensor in enumerate(decoded_fp32):
+                 img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
+                 img = img.transpose(1, 2, 0)
+
+                 if np.isnan(img).any():
+                     print("NaNs found, saving stopped! Step:", step)
+                 pil_img = Image.fromarray((img * 255).astype("uint8"))
+
+                 max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
+                 max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
+                 max_w_overall = max(255, max_w_overall)
+                 max_h_overall = max(255, max_h_overall)
+
+                 padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
+                 all_generated_images.append(padded_img)
+
+                 caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
+                 all_captions.append(caption_text)
+
+                 sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
+                 pil_img.save(sample_path, "JPEG", quality=96)
+
+         if use_wandb and accelerator.is_main_process:
+             wandb_images = [
+                 wandb.Image(img, caption=f"{all_captions[i]}")
+                 for i, img in enumerate(all_generated_images)
+             ]
+             wandb.log({"generated_images": wandb_images})
+         if use_comet_ml and accelerator.is_main_process:
+             for i, img in enumerate(all_generated_images):
+                 comet_experiment.log_image(
+                     image_data=img,
+                     name=f"step_{step}_img_{i}",
+                     step=step,
+                     metadata={"caption": all_captions[i]}
+                 )
+     finally:
+         vae.to("cpu")
+         torch.cuda.empty_cache()
+         gc.collect()
+
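
The guidance step above is standard classifier-free guidance on the predicted flow: v = v_uncond + s·(v_cond − v_uncond). A minimal numeric sketch (illustrative tensors, not model outputs):

import torch

flow = torch.tensor([[1.0], [3.0]])      # stacked [uncond, cond] predictions
flow_uncond, flow_cond = flow.chunk(2)
s = 2.0                                  # guidance_scale (illustrative)
guided = flow_uncond + s * (flow_cond - flow_uncond)
print(float(guided))                     # 1 + 2*(3-1) = 5.0 — pushed past the cond prediction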
+ # --------------------------- Sample generation before training ---------------------------
+ if accelerator.is_main_process:
+     if save_model:
+         print("Generating samples before training starts...")
+         generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
+ accelerator.wait_for_everyone()
+
+ def save_checkpoint(unet, variant=""):
+     if accelerator.is_main_process:
+         if lora_name:
+             save_lora_checkpoint(unet)
+         else:
+             model_to_save = None
+             if not torch_compile:
+                 model_to_save = accelerator.unwrap_model(unet)
+             else:
+                 model_to_save = unet
+
+             if variant != "":
+                 model_to_save.to(dtype=torch.float16).save_pretrained(
+                     os.path.join(checkpoints_folder, f"{project}"), variant=variant
+                 )
+             else:
+                 model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
+
+             unet = unet.to(dtype=dtype)
+
+ # --------------------------- Training loop ---------------------------
+ if accelerator.is_main_process:
+     print(f"Total steps per GPU: {total_training_steps}")
+
+ epoch_loss_points = []
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
+
+ steps_per_epoch = len(dataloader)
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
+ min_loss = 4.
+
+ for epoch in range(start_epoch, start_epoch + num_epochs):
+     batch_losses = []
+     batch_grads = []
+     batch_sampler.set_epoch(epoch)
+     accelerator.wait_for_everyone()
+     unet.train()
+
+     for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
+         with accelerator.accumulate(unet):
+             if not save_model and epoch == 0 and step == 5:
+                 used_gb = torch.cuda.max_memory_allocated() / 1024**3
+                 print(f"Step {step}: {used_gb:.2f} GB")
+
+             # noise
+             noise = torch.randn_like(latents, dtype=latents.dtype)
+             # draw t from [0, 1] (logit-normal timestep sampling)
+             #t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
+             u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
+             t = torch.sigmoid(torch.randn_like(u))
+
+             # interpolate between x0 and noise
+             noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
+             # integer timesteps for the UNet
+             timesteps = (t * scheduler.config.num_train_timesteps).long()
+
+             # --- UNet call with the attention mask ---
+             model_pred = unet(
+                 noisy_latents,
+                 timesteps,
+                 encoder_hidden_states=embeddings,
+                 encoder_attention_mask=attention_mask
+             ).sample
+
+             target = noise - latents
+
+             mse_loss = F.mse_loss(model_pred.float(), target.float())
+             mae_loss = F.l1_loss(model_pred.float(), target.float())
+             batch_losses.append(mse_loss.detach().item())
+
+             # periodic barrier to keep the ranks in lockstep
+             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
+                 accelerator.wait_for_everyone()
+
+             losses_dict = {}
+             losses_dict["mse"] = mse_loss
+             losses_dict["mae"] = mae_loss
+
+             # === Normalize all the losses ===
+             abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
+             total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
+
+             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
+                 accelerator.wait_for_everyone()
+
+             accelerator.backward(total_loss)
+
+             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
+                 accelerator.wait_for_everyone()
+
+             grad = 0.0
+             if not fbp:
+                 if accelerator.sync_gradients:
+                     #with torch.amp.autocast('cuda', enabled=False):
+                     grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
+                     if grad_val is not None:
+                         grad = float(grad_val)
+                     else:
+                         grad = 0.0
+                         print("grad_val is None")
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad(set_to_none=True)
+
+             if accelerator.sync_gradients:
+                 global_step += 1
+                 progress_bar.update(1)
+                 if accelerator.is_main_process:
+                     if fbp:
+                         current_lr = base_learning_rate
+                     else:
+                         current_lr = lr_scheduler.get_last_lr()[0]
+                     batch_grads.append(grad)
+
+                     log_data = {}
+                     log_data["loss_mse"] = mse_loss.detach().item()
+                     log_data["loss_mae"] = mae_loss.detach().item()
+                     log_data["lr"] = current_lr
+                     log_data["grad"] = grad
+                     log_data["loss_norm"] = float(total_loss.item())
+                     for k, c in coeffs.items():
+                         log_data[f"coeff_{k}"] = float(c)
+                     if accelerator.sync_gradients:
+                         if use_wandb:
+                             wandb.log(log_data, step=global_step)
+                         if use_comet_ml:
+                             comet_experiment.log_metrics(log_data, step=global_step)
+
+                     if global_step % sample_interval == 0:
+                         # pass the (emb, mask) tuple as the negative conditioning
+                         if save_model:
+                             generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
+                         elif epoch % 10 == 0:
+                             generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
+                         last_n = sample_interval
+
+                         if save_model:
+                             has_losses = len(batch_losses) > 0
+                             avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
+                             last_loss = batch_losses[-1] if has_losses else 0.0
+                             max_loss = max(avg_sample_loss, last_loss)
+                             should_save = max_loss < min_loss * save_barrier
+                             print(
+                                 f"Saving: {should_save} | Max: {max_loss:.4f} | "
+                                 f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
+                             )
+                             # save and update the threshold
+                             if should_save:
+                                 min_loss = max_loss
+                                 save_checkpoint(unet)
+
+     if accelerator.is_main_process:
+         avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
+         avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
+
+         print(f"\nEpoch {epoch} finished. Mean loss: {avg_epoch_loss:.6f}")
+         log_data_ep = {
+             "epoch_loss": avg_epoch_loss,
+             "epoch_grad": avg_epoch_grad,
+             "epoch": epoch + 1,
+         }
+         if use_wandb:
+             wandb.log(log_data_ep)
+         if use_comet_ml:
+             comet_experiment.log_metrics(log_data_ep)
+
+ if accelerator.is_main_process:
+     print("Training finished! Saving the final model...")
+     #if save_model:
+     save_checkpoint(unet, "fp16")
+     if use_comet_ml:
+         comet_experiment.end()
+ accelerator.free_memory()
+ if torch.distributed.is_initialized():
+     torch.distributed.destroy_process_group()
+
+
+ print("Done!")
config.json ADDED
@@ -0,0 +1,78 @@
+ {
+   "_class_name": "UNet2DConditionModel",
+   "_diffusers_version": "0.36.0",
+   "_name_or_path": "sdxs_flux",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": [
+     2,
+     4,
+     8,
+     16
+   ],
+   "attention_type": "default",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     1024
+   ],
+   "center_input_sample": false,
+   "class_embed_type": null,
+   "class_embeddings_concat": false,
+   "conv_in_kernel": 3,
+   "conv_out_kernel": 3,
+   "cross_attention_dim": 1024,
+   "cross_attention_norm": null,
+   "down_block_types": [
+     "DownBlock2D",
+     "DownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "dropout": 0.0,
+   "dual_cross_attention": false,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 32,
+   "layers_per_block": 2,
+   "mid_block_only_cross_attention": null,
+   "mid_block_scale_factor": 1.0,
+   "mid_block_type": "UNetMidBlock2DCrossAttn",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "out_channels": 32,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_out_scale_factor": 1.0,
+   "resnet_skip_time_act": false,
+   "resnet_time_scale_shift": "default",
+   "reverse_transformer_layers_per_block": null,
+   "sample_size": null,
+   "time_cond_proj_dim": null,
+   "time_embedding_act_fn": null,
+   "time_embedding_dim": null,
+   "time_embedding_type": "positional",
+   "timestep_post_act": null,
+   "transformer_layers_per_block": [
+     1,
+     1,
+     2,
+     4
+   ],
+   "up_block_types": [
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D",
+     "UpBlock2D",
+     "UpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": true
+ }
dataset_flux.py ADDED
@@ -0,0 +1,413 @@
+ # pip install flash-attn --no-build-isolation
+ import torch
+ import torch.nn.functional as F  # needed by encode_texts_batch below
+ import os
+ import gc
+ import numpy as np
+ import random
+ import json
+ import shutil
+ import time
+
+ from datasets import Dataset, load_from_disk, concatenate_datasets
+ from diffusers import AutoencoderKL, AutoencoderKLWan, AutoencoderKLFlux2
+ from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
+ from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
+ from typing import Dict, List, Tuple, Optional, Any
+ from PIL import Image
+ from tqdm import tqdm
+ from datetime import timedelta
+
+ # ---------------- 1️⃣ Settings ----------------
+ dtype = torch.float32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ batch_size = 5
+ min_size = 192  #384 #320 #192 #256 #192
+ max_size = 320  #768 #640 #384 #256 #384
+ step = 64  #64
+ empty_share = 0.0
+ limit = 0
+ # Main processing pipeline
+ folder_path = "/workspace/mjnj"  #alchemist"
+ save_path = "/workspace/sdxs/datasets/mjnj"  #"alchemist"
+ os.makedirs(save_path, exist_ok=True)
+
+ # Helper to free CUDA memory
+ def clear_cuda_memory():
+     if torch.cuda.is_available():
+         used_gb = torch.cuda.max_memory_allocated() / 1024**3
+         print(f"used_gb: {used_gb:.2f} GB")
+         torch.cuda.empty_cache()
+         gc.collect()
+
+ # ---------------- 2️⃣ Model loading ----------------
+ def load_models():
+     print("Loading models...")
+     #vae = AutoencoderKL.from_pretrained("AiArtLab/sdxs", subfolder="vae1x", torch_dtype=dtype).to(device).eval()
+     vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev", subfolder="vae", torch_dtype=dtype).to(device).eval()
+
+     #model_name = "Qwen/Qwen3-0.6B"
+     #tokenizer = AutoTokenizer.from_pretrained(model_name)
+     #model = AutoModelForCausalLM.from_pretrained(
+     #    model_name,
+     #    torch_dtype=dtype,
+     #    device_map=device
+     #).eval()
+     #tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-0.6B', padding_side='left')
+     #model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-0.6B').to("cuda")
+     return vae  #, model, tokenizer
+
+ #vae, model, tokenizer = load_models()
+ vae = load_models()
+
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
+ if shift_factor is None:
+     shift_factor = 0.0
+
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
+ if scaling_factor is None:
+     scaling_factor = 1.0
+
+ latents_mean = getattr(vae.config, "latents_mean", None)
+ latents_std = getattr(vae.config, "latents_std", None)
+
+ # ---------------- 3️⃣ Transforms ----------------
+ def get_image_transform(min_size=256, max_size=512, step=64):
+     def transform(img, dry_run=False):
+         # Remember the original image dimensions
+         original_width, original_height = img.size
+
+         # 0. Resize: scale so the longer side equals max_size
+         if original_width >= original_height:
+             new_width = max_size
+             new_height = int(max_size * original_height / original_width)
+         else:
+             new_height = max_size
+             new_width = int(max_size * original_width / original_height)
+
+         if new_height < min_size or new_width < min_size:
+             # 1. Resize: scale so the shorter side equals min_size
+             if original_width <= original_height:
+                 new_width = min_size
+                 new_height = int(min_size * original_height / original_width)
+             else:
+                 new_height = min_size
+                 new_width = int(min_size * original_width / original_height)
+
+         # 2. If a side exceeds max_size, prepare to crop; snap to the step grid
+         crop_width = min(max_size, (new_width // step) * step)
+         crop_height = min(max_size, (new_height // step) * step)
+
+         # Make sure the crop is not smaller than min_size
+         crop_width = max(min_size, crop_width)
+         crop_height = max(min_size, crop_height)
+
+         # If only the output dimensions were requested
+         if dry_run:
+             return crop_width, crop_height
+
+         # Convert to RGB and resize
+         img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
+
+         # Crop coordinates (offset a third from the top to dodge watermarks)
+         top = (new_height - crop_height) // 3
+         left = 0
+
+         # Crop the image
+         img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
+
+         # Final dimensions after all transforms
+         final_width, final_height = img_cropped.size
+
+         # to tensor
+         img_tensor = ToTensor()(img_cropped)
+         img_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img_tensor)
+         return img_tensor, img_cropped, final_width, final_height
+
+     return transform
+
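
To see what the bucketing does, here is a small sketch of the dry_run path with a hypothetical input size (the numbers follow directly from the code above):

from PIL import Image

transform = get_image_transform(min_size=192, max_size=320, step=64)

# A hypothetical 1000x600 image: the long side is scaled to 320,
# then the crop snaps to the 64-pixel grid.
img = Image.new("RGB", (1000, 600))
print(transform(img, dry_run=True))   # (320, 192)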
+ # ---------------- 4️⃣ Processing functions ----------------
+ def last_token_pool(last_hidden_states: torch.Tensor,
+                     attention_mask: torch.Tensor) -> torch.Tensor:
+     # Detect whether left padding is used
+     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+     if left_padding:
+         return last_hidden_states[:, -1]
+     else:
+         sequence_lengths = attention_mask.sum(dim=1) - 1
+         batch_size = last_hidden_states.shape[0]
+         return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
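
A small sketch of why the right-padding branch indexes with sequence_lengths (illustrative tensors):

import torch

hidden = torch.arange(12.0).reshape(1, 4, 3)   # [B=1, L=4, D=3]
mask = torch.tensor([[1, 1, 1, 0]])            # right padding: last real token at index 2
pooled = last_token_pool(hidden, mask)
print(pooled)                                  # row 2 of the sequence, i.e. [6., 7., 8.]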
+ def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150, normalize=False):
+     with torch.inference_mode():
+         # Tokenization
+         batch = tokenizer(
+             texts,
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=max_length
+         ).to(device)
+
+         # Forward pass through the model
+         #outputs = model(**batch)
+
+         # Last-token pooling
+         #embeddings = last_token_pool(outputs.last_hidden_state, batch["attention_mask"])
+
+         # L2 normalization (optional; usually needed for semantic search)
+         #if normalize:
+         #    embeddings = F.normalize(embeddings, p=2, dim=1)
+
+         # Forward pass through the base model (inside the CausalLM)
+         outputs = model.model(**batch, output_hidden_states=True)
+
+         # Take the last layer (embeddings of all tokens)
+         hidden_states = outputs.hidden_states[-1]  # [B, L, D]
+
+         # Optionally normalize each token (as in CLIP)
+         if normalize:
+             hidden_states = F.normalize(hidden_states, p=2, dim=-1)
+
+         return hidden_states.cpu().numpy()  # embeddings.unsqueeze(1).cpu().numpy()
+
+ def clean_label(label):
+     label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "").replace("The image depicts ", "").replace("The image presents ", "").replace("The image features ", "").replace("The image portrays ", "").replace("The image is ", "").strip()
+     if label.startswith("."):
+         label = label[1:].lstrip()
+     return label
+
+ def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
+     """
+     Processes a list of labels for classifier-free guidance.
+
+     With probability prob_to_make_empty:
+     - the label in the first list is replaced with an empty string;
+     - the label in the second list gets a "zero:" prefix.
+
+     Otherwise both lists keep the original label.
+     """
+     labels_for_model = []
+     labels_for_logging = []
+
+     for label in original_labels:
+         if random.random() < prob_to_make_empty:
+             labels_for_model.append("")  # the model sees an empty string
+             labels_for_logging.append(f"zero: {label}")  # the prefixed copy goes to the logs/dataset
+         else:
+             labels_for_model.append(label)  # keep the original label for the model
+             labels_for_logging.append(label)  # and for the logs
+
+     return labels_for_model, labels_for_logging
+
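
A quick illustration of the dropout bookkeeping (hypothetical label; prob_to_make_empty=1.0 makes the outcome deterministic). Note that the "zero:" marker is what the training collate later maps back to an empty caption:

model_labels, log_labels = process_labels_for_guidance(["a red car"], prob_to_make_empty=1.0)
print(model_labels)   # [''] — what the text encoder would see
print(log_labels)     # ['zero: a red car'] — the marker stored in the dataset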
+ def encode_to_latents(images, texts):
+     transform = get_image_transform(min_size, max_size, step)
+
+     try:
+         # Process the images (all the same size)
+         transformed_tensors = []
+         pil_images = []
+         widths, heights = [], []
+
+         # Apply the transform to every image
+         for img in images:
+             try:
+                 t_img, pil_img, w, h = transform(img)
+                 transformed_tensors.append(t_img)
+                 pil_images.append(pil_img)
+                 widths.append(w)
+                 heights.append(h)
+             except Exception as e:
+                 print(f"Transform error: {e}")
+                 continue
+
+         if not transformed_tensors:
+             return None
+
+         # Build the batch
+         batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
+         # NB: for an image batch this is [B, C, H, W] (4-D), so the branch below never
+         # fires; for a video VAE the intent would be to lift 4-D to [B, C, 1, H, W]
+         if batch_tensor.ndim == 5:
+             batch_tensor = batch_tensor.unsqueeze(2)  # [B, C, 1, H, W]
+
+         # Encode the batch
+         with torch.no_grad():
+             posteriors = vae.encode(batch_tensor).latent_dist.mode()
+             # NB: diffusers usually scales as (x - shift) * scale; dividing is only
+             # equivalent when scaling_factor == 1.0 — keep this in sync with the trainer
+             latents = (posteriors - shift_factor) / scaling_factor
+
+         latents_np = latents.to(dtype).cpu().numpy()
+
+         # Process the texts
+         text_labels = [clean_label(text) for text in texts]
+
+         model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
+         #embeddings = encode_texts_batch(model_prompts, tokenizer, model)
+
+         return {
+             "vae": latents_np,
+             #"embeddings": embeddings,
+             "text": text_labels,
+             "width": widths,
+             "height": heights
+         }
+
+     except Exception as e:
+         print(f"Fatal error in encode_to_latents: {e}")
+         raise
+
+
+ # ---------------- 5️⃣ Processing a folder of images and texts ----------------
+ def process_folder(folder_path, limit=None):
+     """
+     Recursively walks the given directory and all subdirectories,
+     collecting paths to images and their matching text files.
+     """
+     image_paths = []
+     text_paths = []
+     width = []
+     height = []
+     transform = get_image_transform(min_size, max_size, step)
+
+     # os.walk for a recursive directory traversal
+     for root, dirs, files in os.walk(folder_path):
+         for filename in files:
+             # Is the file an image?
+             if filename.lower().endswith((".jpg", ".jpeg", ".png")):
+                 image_path = os.path.join(root, filename)
+                 try:
+                     img = Image.open(image_path)
+                 except Exception as e:
+                     print(f"Failed to open {image_path}: {e}")
+                     os.remove(image_path)
+                     text_path = os.path.splitext(image_path)[0] + ".txt"
+                     if os.path.exists(text_path):
+                         os.remove(text_path)
+                     continue
+                 # Run the transform only to compute the output dimensions
+                 w, h = transform(img, dry_run=True)
+                 # Path to the companion text file
+                 text_path = os.path.splitext(image_path)[0] + ".txt"
+
+                 # Record the paths if the text file exists
+                 if os.path.exists(text_path) and min(w, h) > 0:
+                     image_paths.append(image_path)
+                     text_paths.append(text_path)
+                     width.append(w)
+                     height.append(h)
+
+                 # Respect the limit on the number of images
+                 if limit and limit > 0 and len(image_paths) >= limit:
+                     print(f"Reached the limit of {limit} images")
+                     return image_paths, text_paths, width, height
+
+     print(f"Found {len(image_paths)} images with text descriptions")
+     return image_paths, text_paths, width, height
+
+ def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
+     total_files = len(image_paths)
+     start_time = time.time()
+     chunks = range(0, total_files, chunk_size)
+
+     for chunk_idx, start in enumerate(chunks, 1):
+         end = min(start + chunk_size, total_files)
+         chunk_image_paths = image_paths[start:end]
+         chunk_text_paths = text_paths[start:end]
+         chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
+         chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)
+
+         # Read the texts
+         chunk_texts = []
+         for text_path in chunk_text_paths:
+             try:
+                 with open(text_path, 'r', encoding='utf-8') as f:
+                     text = f.read().strip()
+                 chunk_texts.append(text)
+             except Exception as e:
+                 print(f"Failed to read {text_path}: {e}")
+                 chunk_texts.append("")
+
+         # Group the images by size
+         size_groups = {}
+         for i in range(len(chunk_image_paths)):
+             size_key = (chunk_widths[i], chunk_heights[i])
+             if size_key not in size_groups:
+                 size_groups[size_key] = {"image_paths": [], "texts": []}
+             size_groups[size_key]["image_paths"].append(chunk_image_paths[i])
+             size_groups[size_key]["texts"].append(chunk_texts[i])
+
+         # Process each size group separately
+         for size_key, group_data in size_groups.items():
+             print(f"Processing group {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} images")
+
+             group_dataset = Dataset.from_dict({
+                 "image_path": group_data["image_paths"],
+                 "text": group_data["texts"]
+             })
+
+             # The requested batch_size is safe here because all images in the group share one size
+             processed_group = group_dataset.map(
+                 lambda examples: encode_to_latents(
+                     [Image.open(path) for path in examples["image_path"]],
+                     examples["text"]
+                 ),
+                 batched=True,
+                 batch_size=batch_size,
+                 #remove_columns=["image_path"],
+                 desc=f"Processing size group {size_key[0]}x{size_key[1]}"
+             )
+
+             # Save the group's results
+             group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
+             processed_group.save_to_disk(group_save_path)
+             clear_cuda_memory()
+             elapsed = time.time() - start_time
+             processed = (chunk_idx - 1) * chunk_size + sum([len(sg["image_paths"]) for sg in list(size_groups.values())[:list(size_groups.values()).index(group_data) + 1]])
+             if processed > 0:
+                 remaining = (elapsed / processed) * (total_files - processed)
+                 elapsed_str = str(timedelta(seconds=int(elapsed)))
+                 remaining_str = str(timedelta(seconds=int(remaining)))
+                 print(f"ETA: elapsed {elapsed_str}, remaining {remaining_str}, progress {processed}/{total_files} ({processed/total_files:.1%})")
+
+ # ---------------- 7️⃣ Merging chunks ----------------
+ def combine_chunks(temp_path, final_path):
+     """Merge the processed chunks into the final dataset."""
+     chunks = sorted([
+         os.path.join(temp_path, d)
+         for d in os.listdir(temp_path)
+         if d.startswith("chunk_")
+     ])
+
+     datasets = [load_from_disk(chunk) for chunk in chunks]
+     combined = concatenate_datasets(datasets)
+     combined.save_to_disk(final_path)
+
+     print(f"✅ Dataset saved to: {final_path}")
+
+
+
+ # Temporary folder for the chunks
+ temp_path = f"{save_path}_temp"
+ os.makedirs(temp_path, exist_ok=True)
+
+ # Collect the file lists
+ image_paths, text_paths, width, height = process_folder(folder_path, limit)
+ print(f"Found {len(image_paths)} images in total")
+
+ # Chunked processing
+ process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size)
+
+ # Remove the source folder
+ try:
+     shutil.rmtree(folder_path)
+     print(f"✅ Folder {folder_path} removed")
+ except Exception as e:
+     print(f"⚠️ Failed to remove the folder: {e}")
+
+ # Merge the chunks into the final dataset
+ combine_chunks(temp_path, save_path)
+
+ # Remove the temporary folder
+ try:
+     shutil.rmtree(temp_path)
+     print(f"✅ Temporary folder {temp_path} removed")
+ except Exception as e:
+     print(f"⚠️ Failed to remove the temporary folder: {e}")
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:41dfe79753fadad9ba8e148ac73e1b61f23c7a1146cece5147c0041d613298bd
+ size 3195253456
sdxs_create_flux.ipynb ADDED
@@ -0,0 +1,600 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "6bf71a1a-1bf0-42c7-8709-6686e8d2f46c",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "test unet\n",
14
+ "Количество параметров: 798780960\n",
15
+ "Output shape: torch.Size([1, 32, 60, 48])\n",
16
+ "UNet2DConditionModel(\n",
17
+ " (conv_in): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
18
+ " (time_proj): Timesteps()\n",
19
+ " (time_embedding): TimestepEmbedding(\n",
20
+ " (linear_1): Linear(in_features=128, out_features=512, bias=True)\n",
21
+ " (act): SiLU()\n",
22
+ " (linear_2): Linear(in_features=512, out_features=512, bias=True)\n",
23
+ " )\n",
24
+ " (down_blocks): ModuleList(\n",
25
+ " (0): DownBlock2D(\n",
26
+ " (resnets): ModuleList(\n",
27
+ " (0-1): 2 x ResnetBlock2D(\n",
28
+ " (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
29
+ " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
30
+ " (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)\n",
31
+ " (norm2): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
32
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
33
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
34
+ " (nonlinearity): SiLU()\n",
35
+ " )\n",
36
+ " )\n",
37
+ " (downsamplers): ModuleList(\n",
38
+ " (0): Downsample2D(\n",
39
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
40
+ " )\n",
41
+ " )\n",
42
+ " )\n",
43
+ " (1): DownBlock2D(\n",
44
+ " (resnets): ModuleList(\n",
45
+ " (0): ResnetBlock2D(\n",
46
+ " (norm1): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
47
+ " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
48
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
49
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
50
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
51
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
52
+ " (nonlinearity): SiLU()\n",
53
+ " (conv_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))\n",
54
+ " )\n",
55
+ " (1): ResnetBlock2D(\n",
56
+ " (norm1): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
57
+ " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
58
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
59
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
60
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
61
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
62
+ " (nonlinearity): SiLU()\n",
63
+ " )\n",
64
+ " )\n",
65
+ " (downsamplers): ModuleList(\n",
66
+ " (0): Downsample2D(\n",
67
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
68
+ " )\n",
69
+ " )\n",
70
+ " )\n",
71
+ " (2): CrossAttnDownBlock2D(\n",
72
+ " (attentions): ModuleList(\n",
73
+ " (0-1): 2 x Transformer2DModel(\n",
74
+ " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
75
+ " (proj_in): Linear(in_features=512, out_features=512, bias=True)\n",
76
+ " (transformer_blocks): ModuleList(\n",
77
+ " (0-1): 2 x BasicTransformerBlock(\n",
78
+ " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
79
+ " (attn1): Attention(\n",
80
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
81
+ " (to_k): Linear(in_features=512, out_features=512, bias=False)\n",
82
+ " (to_v): Linear(in_features=512, out_features=512, bias=False)\n",
83
+ " (to_out): ModuleList(\n",
84
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
85
+ " (1): Dropout(p=0.0, inplace=False)\n",
86
+ " )\n",
87
+ " )\n",
88
+ " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
89
+ " (attn2): Attention(\n",
90
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
91
+ " (to_k): Linear(in_features=1024, out_features=512, bias=False)\n",
92
+ " (to_v): Linear(in_features=1024, out_features=512, bias=False)\n",
93
+ " (to_out): ModuleList(\n",
94
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
95
+ " (1): Dropout(p=0.0, inplace=False)\n",
96
+ " )\n",
97
+ " )\n",
98
+ " (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
99
+ " (ff): FeedForward(\n",
100
+ " (net): ModuleList(\n",
101
+ " (0): GEGLU(\n",
102
+ " (proj): Linear(in_features=512, out_features=4096, bias=True)\n",
103
+ " )\n",
104
+ " (1): Dropout(p=0.0, inplace=False)\n",
105
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
106
+ " )\n",
107
+ " )\n",
108
+ " )\n",
109
+ " )\n",
110
+ " (proj_out): Linear(in_features=512, out_features=512, bias=True)\n",
111
+ " )\n",
112
+ " )\n",
113
+ " (resnets): ModuleList(\n",
114
+ " (0): ResnetBlock2D(\n",
115
+ " (norm1): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
116
+ " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
117
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
118
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
119
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
120
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
121
+ " (nonlinearity): SiLU()\n",
122
+ " (conv_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
123
+ " )\n",
124
+ " (1): ResnetBlock2D(\n",
125
+ " (norm1): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
126
+ " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
127
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
128
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
129
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
130
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
131
+ " (nonlinearity): SiLU()\n",
132
+ " )\n",
133
+ " )\n",
134
+ " (downsamplers): ModuleList(\n",
135
+ " (0): Downsample2D(\n",
136
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
137
+ " )\n",
138
+ " )\n",
139
+ " )\n",
140
+ " (3): CrossAttnDownBlock2D(\n",
141
+ " (attentions): ModuleList(\n",
142
+ " (0-1): 2 x Transformer2DModel(\n",
143
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
144
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
145
+ " (transformer_blocks): ModuleList(\n",
146
+ " (0-3): 4 x BasicTransformerBlock(\n",
147
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
148
+ " (attn1): Attention(\n",
149
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
150
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
151
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
152
+ " (to_out): ModuleList(\n",
153
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
154
+ " (1): Dropout(p=0.0, inplace=False)\n",
155
+ " )\n",
156
+ " )\n",
157
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
158
+ " (attn2): Attention(\n",
159
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
160
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
161
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
162
+ " (to_out): ModuleList(\n",
163
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
164
+ " (1): Dropout(p=0.0, inplace=False)\n",
165
+ " )\n",
166
+ " )\n",
167
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
168
+ " (ff): FeedForward(\n",
169
+ " (net): ModuleList(\n",
170
+ " (0): GEGLU(\n",
171
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
172
+ " )\n",
173
+ " (1): Dropout(p=0.0, inplace=False)\n",
174
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
175
+ " )\n",
176
+ " )\n",
177
+ " )\n",
178
+ " )\n",
179
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
180
+ " )\n",
181
+ " )\n",
182
+ " (resnets): ModuleList(\n",
183
+ " (0): ResnetBlock2D(\n",
184
+ " (norm1): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
185
+ " (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
186
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
187
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
188
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
189
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
190
+ " (nonlinearity): SiLU()\n",
191
+ " (conv_shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
192
+ " )\n",
193
+ " (1): ResnetBlock2D(\n",
194
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
195
+ " (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
196
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
197
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
198
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
199
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
200
+ " (nonlinearity): SiLU()\n",
201
+ " )\n",
202
+ " )\n",
203
+ " )\n",
204
+ " )\n",
205
+ " (up_blocks): ModuleList(\n",
206
+ " (0): CrossAttnUpBlock2D(\n",
207
+ " (attentions): ModuleList(\n",
208
+ " (0-2): 3 x Transformer2DModel(\n",
209
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
210
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
211
+ " (transformer_blocks): ModuleList(\n",
212
+ " (0-3): 4 x BasicTransformerBlock(\n",
213
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
214
+ " (attn1): Attention(\n",
215
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
216
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
217
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
218
+ " (to_out): ModuleList(\n",
219
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
220
+ " (1): Dropout(p=0.0, inplace=False)\n",
221
+ " )\n",
222
+ " )\n",
223
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
224
+ " (attn2): Attention(\n",
225
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
226
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
227
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
228
+ " (to_out): ModuleList(\n",
229
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
230
+ " (1): Dropout(p=0.0, inplace=False)\n",
231
+ " )\n",
232
+ " )\n",
233
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
234
+ " (ff): FeedForward(\n",
235
+ " (net): ModuleList(\n",
236
+ " (0): GEGLU(\n",
237
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
238
+ " )\n",
239
+ " (1): Dropout(p=0.0, inplace=False)\n",
240
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
241
+ " )\n",
242
+ " )\n",
243
+ " )\n",
244
+ " )\n",
245
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
246
+ " )\n",
247
+ " )\n",
248
+ " (resnets): ModuleList(\n",
249
+ " (0-1): 2 x ResnetBlock2D(\n",
250
+ " (norm1): GroupNorm(32, 2048, eps=1e-05, affine=True)\n",
251
+ " (conv1): Conv2d(2048, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
252
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
253
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
254
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
255
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
256
+ " (nonlinearity): SiLU()\n",
257
+ " (conv_shortcut): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
258
+ " )\n",
259
+ " (2): ResnetBlock2D(\n",
260
+ " (norm1): GroupNorm(32, 1536, eps=1e-05, affine=True)\n",
261
+ " (conv1): Conv2d(1536, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
262
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
263
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
264
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
265
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
266
+ " (nonlinearity): SiLU()\n",
267
+ " (conv_shortcut): Conv2d(1536, 1024, kernel_size=(1, 1), stride=(1, 1))\n",
268
+ " )\n",
269
+ " )\n",
270
+ " (upsamplers): ModuleList(\n",
271
+ " (0): Upsample2D(\n",
272
+ " (conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
273
+ " )\n",
274
+ " )\n",
275
+ " )\n",
276
+ " (1): CrossAttnUpBlock2D(\n",
277
+ " (attentions): ModuleList(\n",
278
+ " (0-2): 3 x Transformer2DModel(\n",
279
+ " (norm): GroupNorm(32, 512, eps=1e-06, affine=True)\n",
280
+ " (proj_in): Linear(in_features=512, out_features=512, bias=True)\n",
281
+ " (transformer_blocks): ModuleList(\n",
282
+ " (0-1): 2 x BasicTransformerBlock(\n",
283
+ " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
284
+ " (attn1): Attention(\n",
285
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
286
+ " (to_k): Linear(in_features=512, out_features=512, bias=False)\n",
287
+ " (to_v): Linear(in_features=512, out_features=512, bias=False)\n",
288
+ " (to_out): ModuleList(\n",
289
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
290
+ " (1): Dropout(p=0.0, inplace=False)\n",
291
+ " )\n",
292
+ " )\n",
293
+ " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
294
+ " (attn2): Attention(\n",
295
+ " (to_q): Linear(in_features=512, out_features=512, bias=False)\n",
296
+ " (to_k): Linear(in_features=1024, out_features=512, bias=False)\n",
297
+ " (to_v): Linear(in_features=1024, out_features=512, bias=False)\n",
298
+ " (to_out): ModuleList(\n",
299
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
300
+ " (1): Dropout(p=0.0, inplace=False)\n",
301
+ " )\n",
302
+ " )\n",
303
+ " (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
304
+ " (ff): FeedForward(\n",
305
+ " (net): ModuleList(\n",
306
+ " (0): GEGLU(\n",
307
+ " (proj): Linear(in_features=512, out_features=4096, bias=True)\n",
308
+ " )\n",
309
+ " (1): Dropout(p=0.0, inplace=False)\n",
310
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
311
+ " )\n",
312
+ " )\n",
313
+ " )\n",
314
+ " )\n",
315
+ " (proj_out): Linear(in_features=512, out_features=512, bias=True)\n",
316
+ " )\n",
317
+ " )\n",
318
+ " (resnets): ModuleList(\n",
319
+ " (0): ResnetBlock2D(\n",
320
+ " (norm1): GroupNorm(32, 1536, eps=1e-05, affine=True)\n",
321
+ " (conv1): Conv2d(1536, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
322
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
323
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
324
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
325
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
326
+ " (nonlinearity): SiLU()\n",
327
+ " (conv_shortcut): Conv2d(1536, 512, kernel_size=(1, 1), stride=(1, 1))\n",
328
+ " )\n",
329
+ " (1): ResnetBlock2D(\n",
330
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
331
+ " (conv1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
332
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
333
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
334
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
335
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
336
+ " (nonlinearity): SiLU()\n",
337
+ " (conv_shortcut): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))\n",
338
+ " )\n",
339
+ " (2): ResnetBlock2D(\n",
340
+ " (norm1): GroupNorm(32, 768, eps=1e-05, affine=True)\n",
341
+ " (conv1): Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
342
+ " (time_emb_proj): Linear(in_features=512, out_features=512, bias=True)\n",
343
+ " (norm2): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
344
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
345
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
346
+ " (nonlinearity): SiLU()\n",
347
+ " (conv_shortcut): Conv2d(768, 512, kernel_size=(1, 1), stride=(1, 1))\n",
348
+ " )\n",
349
+ " )\n",
350
+ " (upsamplers): ModuleList(\n",
351
+ " (0): Upsample2D(\n",
352
+ " (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
353
+ " )\n",
354
+ " )\n",
355
+ " )\n",
356
+ " (2): UpBlock2D(\n",
357
+ " (resnets): ModuleList(\n",
358
+ " (0): ResnetBlock2D(\n",
359
+ " (norm1): GroupNorm(32, 768, eps=1e-05, affine=True)\n",
360
+ " (conv1): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
361
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
362
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
363
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
364
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
365
+ " (nonlinearity): SiLU()\n",
366
+ " (conv_shortcut): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1))\n",
367
+ " )\n",
368
+ " (1): ResnetBlock2D(\n",
369
+ " (norm1): GroupNorm(32, 512, eps=1e-05, affine=True)\n",
370
+ " (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
371
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
372
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
373
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
374
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
375
+ " (nonlinearity): SiLU()\n",
376
+ " (conv_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))\n",
377
+ " )\n",
378
+ " (2): ResnetBlock2D(\n",
379
+ " (norm1): GroupNorm(32, 384, eps=1e-05, affine=True)\n",
380
+ " (conv1): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
381
+ " (time_emb_proj): Linear(in_features=512, out_features=256, bias=True)\n",
382
+ " (norm2): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
383
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
384
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
385
+ " (nonlinearity): SiLU()\n",
386
+ " (conv_shortcut): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1))\n",
387
+ " )\n",
388
+ " )\n",
389
+ " (upsamplers): ModuleList(\n",
390
+ " (0): Upsample2D(\n",
391
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
392
+ " )\n",
393
+ " )\n",
394
+ " )\n",
395
+ " (3): UpBlock2D(\n",
396
+ " (resnets): ModuleList(\n",
397
+ " (0): ResnetBlock2D(\n",
398
+ " (norm1): GroupNorm(32, 384, eps=1e-05, affine=True)\n",
399
+ " (conv1): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
400
+ " (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)\n",
401
+ " (norm2): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
402
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
403
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
404
+ " (nonlinearity): SiLU()\n",
405
+ " (conv_shortcut): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1))\n",
406
+ " )\n",
407
+ " (1-2): 2 x ResnetBlock2D(\n",
408
+ " (norm1): GroupNorm(32, 256, eps=1e-05, affine=True)\n",
409
+ " (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
410
+ " (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)\n",
411
+ " (norm2): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
412
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
413
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
414
+ " (nonlinearity): SiLU()\n",
415
+ " (conv_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
416
+ " )\n",
417
+ " )\n",
418
+ " )\n",
419
+ " )\n",
420
+ " (mid_block): UNetMidBlock2DCrossAttn(\n",
421
+ " (attentions): ModuleList(\n",
422
+ " (0): Transformer2DModel(\n",
423
+ " (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)\n",
424
+ " (proj_in): Linear(in_features=1024, out_features=1024, bias=True)\n",
425
+ " (transformer_blocks): ModuleList(\n",
426
+ " (0-3): 4 x BasicTransformerBlock(\n",
427
+ " (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
428
+ " (attn1): Attention(\n",
429
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
430
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
431
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
432
+ " (to_out): ModuleList(\n",
433
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
434
+ " (1): Dropout(p=0.0, inplace=False)\n",
435
+ " )\n",
436
+ " )\n",
437
+ " (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
438
+ " (attn2): Attention(\n",
439
+ " (to_q): Linear(in_features=1024, out_features=1024, bias=False)\n",
440
+ " (to_k): Linear(in_features=1024, out_features=1024, bias=False)\n",
441
+ " (to_v): Linear(in_features=1024, out_features=1024, bias=False)\n",
442
+ " (to_out): ModuleList(\n",
443
+ " (0): Linear(in_features=1024, out_features=1024, bias=True)\n",
444
+ " (1): Dropout(p=0.0, inplace=False)\n",
445
+ " )\n",
446
+ " )\n",
447
+ " (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
448
+ " (ff): FeedForward(\n",
449
+ " (net): ModuleList(\n",
450
+ " (0): GEGLU(\n",
451
+ " (proj): Linear(in_features=1024, out_features=8192, bias=True)\n",
452
+ " )\n",
453
+ " (1): Dropout(p=0.0, inplace=False)\n",
454
+ " (2): Linear(in_features=4096, out_features=1024, bias=True)\n",
455
+ " )\n",
456
+ " )\n",
457
+ " )\n",
458
+ " )\n",
459
+ " (proj_out): Linear(in_features=1024, out_features=1024, bias=True)\n",
460
+ " )\n",
461
+ " )\n",
462
+ " (resnets): ModuleList(\n",
463
+ " (0-1): 2 x ResnetBlock2D(\n",
464
+ " (norm1): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
465
+ " (conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
466
+ " (time_emb_proj): Linear(in_features=512, out_features=1024, bias=True)\n",
467
+ " (norm2): GroupNorm(32, 1024, eps=1e-05, affine=True)\n",
468
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
469
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
470
+ " (nonlinearity): SiLU()\n",
471
+ " )\n",
472
+ " )\n",
473
+ " )\n",
474
+ " (conv_norm_out): GroupNorm(32, 128, eps=1e-05, affine=True)\n",
475
+ " (conv_act): SiLU()\n",
476
+ " (conv_out): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
477
+ ")\n",
478
+ "Output shape: torch.Size([1, 32, 60, 48])\n"
479
+ ]
480
+ }
481
+ ],
482
+ "source": [
483
+ "config_sdxs = {\n",
484
+ " # === Основные размеры и каналы ===\n",
485
+ " \"in_channels\": 32, # Количество входных каналов (совместимость с 16-канальным VAE)\n",
486
+ " \"out_channels\": 32, # Количество выходных каналов (симметрично in_channels)\n",
487
+ " \"center_input_sample\": False, # Отключение центрирования входных данных (стандарт для диффузионных моделей)\n",
488
+ " \"flip_sin_to_cos\": True, # Автоматическое преобразование sin/cos в эмбеддингах времени (для стабильности)\n",
489
+ " \"freq_shift\": 0, # Сдвиг частоты (0 - стандартное значение для частотных эмбеддингов)\n",
490
+ "\n",
491
+ " # === Архитектура блоков ===\n",
492
+ " \"down_block_types\": [ # Типы блоков энкодера (иерархия обработки):\n",
493
+ " \"DownBlock2D\",\n",
494
+ " \"DownBlock2D\",\n",
495
+ " \"CrossAttnDownBlock2D\",\n",
496
+ " \"CrossAttnDownBlock2D\",\n",
497
+ " ],\n",
498
+ " \"mid_block_type\": \"UNetMidBlock2DCrossAttn\", # Центральный блок с cross-attention (бутылочное горлышко сети)\n",
499
+ " \"up_block_types\": [ # Типы блоков декодера (восстановление изображения):\n",
500
+ " \"CrossAttnUpBlock2D\",\n",
501
+ " \"CrossAttnUpBlock2D\", \n",
502
+ " \"UpBlock2D\",\n",
503
+ " \"UpBlock2D\",\n",
504
+ " ],\n",
505
+ " \"only_cross_attention\": False, # Использование как cross-attention, так и self-attention\n",
506
+ "\n",
507
+ " # === Конфигурация каналов ===\n",
508
+ " \"block_out_channels\": [128, 256, 512, 1024], \n",
509
+ " \"layers_per_block\": 2, # Число слоев в блоках\n",
510
+ " \"downsample_padding\": 1, # Паддинг при уменьшении разрешения\n",
511
+ " \"mid_block_scale_factor\": 1.0, # Усиление сигнала в центральном блоке\n",
512
+ "\n",
513
+ " # === Нормализация ===\n",
514
+ " \"norm_num_groups\": 32, # Число групп для GroupNorm (оптимально для стабильности)\n",
515
+ " \"norm_eps\": 1e-05, # Эпсилон для нормализации (стандартное значение)\n",
516
+ "\n",
517
+ " # === Cross-Attention ===\n",
518
+ " \"cross_attention_dim\": 1024, # Размерность текстовых эмбеддинго\n",
519
+ " \n",
520
+ " \"transformer_layers_per_block\": [1, 1, 2, 4], # Число трансформерных слоев (уменьшение с глубиной)\n",
521
+ " \"attention_head_dim\": [2, 4, 8, 16], # Размерность головы внимания \n",
522
+ " \"dual_cross_attention\": False, # Отключение двойного внимания (упрощение архитектуры)\n",
523
+ " \"use_linear_projection\": True, # Изменено на True для лучшей организации памяти\n",
524
+ "\n",
525
+ " # === ResNet Блоки ===\n",
526
+ " \"resnet_time_scale_shift\": \"default\", # Способ интеграции временных эмбеддингов\n",
527
+ " \"resnet_skip_time_act\": False, # Отключение активации в skip-соединениях\n",
528
+ " \"resnet_out_scale_factor\": 1.0, # Коэффициент масштабирования выхода ResNet\n",
529
+ "\n",
530
+ " # === Временные эмбеддинги ===\n",
531
+ " \"time_embedding_type\": \"positional\", # Тип временных эмбеддингов (стандартный подход)\n",
532
+ "\n",
533
+ " # === Свертки ===\n",
534
+ " \"conv_in_kernel\": 3, # Ядро входной свертки (баланс между рецептивным полем и параметрами)\n",
535
+ " \"conv_out_kernel\": 3, # Ядро выходной свертки (симметрично входной)\n",
536
+ "}\n",
537
+ "\n",
538
+ "if 1:\n",
539
+ " checkpoint_path = \"/workspace/sdxs/sdxs_flux\"#\"sdxs\"\n",
540
+ " import torch\n",
541
+ " from diffusers import UNet2DConditionModel\n",
542
+ " print(\"test unet\")\n",
543
+ " new_unet = UNet2DConditionModel(**config_sdxs).to(\"cuda\", dtype=torch.float16)\n",
544
+ "\n",
545
+ " assert all(ch % 32 == 0 for ch in new_unet.config[\"block_out_channels\"]), \"Каналы должны быть кратны 32\"\n",
546
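+ " # Note (editor's addition): the assert guards GroupNorm: every entry of\n",
+ " # block_out_channels must be divisible by norm_num_groups=32, or model init fails.\n",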
+ " num_params = sum(p.numel() for p in new_unet.parameters())\n",
547
+ " print(f\"Количество параметров: {num_params}\")\n",
548
+ "\n",
549
+ " # Генерация тестового латента (640x512 в latent space)\n",
550
+ " test_latent = torch.randn(1, 32, 60, 48).to(\"cuda\", dtype=torch.float16) # 60x48 ≈ 512px\n",
551
+ " timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n",
552
+ " encoder_hidden_states = torch.randn(1, 77, 1024).to(\"cuda\", dtype=torch.float16)\n",
553
+ " \n",
554
+ " with torch.no_grad():\n",
555
+ " output = new_unet(\n",
556
+ " test_latent, \n",
557
+ " timesteps, \n",
558
+ " encoder_hidden_states\n",
559
+ " ).sample\n",
560
+ " \n",
561
+ " print(f\"Output shape: {output.shape}\") \n",
562
+ " new_unet.save_pretrained(checkpoint_path)\n",
563
+ " print(new_unet)\n",
564
+ " del new_unet\n",
565
+ " torch.cuda.empty_cache()\n",
566
+ " print(f\"Output shape: {output.shape}\") \n",
567
+ " # Количество параметров: 1228819856 2042991632 1889899536"
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": null,
573
+ "id": "a81bca2d-d96d-4794-90f4-b54ea8682b52",
574
+ "metadata": {},
575
+ "outputs": [],
576
+ "source": []
577
+ }
578
+ ],
579
+ "metadata": {
580
+ "kernelspec": {
581
+ "display_name": "Python 3 (ipykernel)",
582
+ "language": "python",
583
+ "name": "python3"
584
+ },
585
+ "language_info": {
586
+ "codemirror_mode": {
587
+ "name": "ipython",
588
+ "version": 3
589
+ },
590
+ "file_extension": ".py",
591
+ "mimetype": "text/x-python",
592
+ "name": "python",
593
+ "nbconvert_exporter": "python",
594
+ "pygments_lexer": "ipython3",
595
+ "version": "3.12.3"
596
+ }
597
+ },
598
+ "nbformat": 4,
599
+ "nbformat_minor": 5
600
+ }
train_flux.py ADDED
@@ -0,0 +1,802 @@
1
+ #from comet_ml import Experiment
2
+ import os
3
+ import math
4
+ import torch
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from torch.utils.data import DataLoader, Sampler
8
+ from torch.utils.data.distributed import DistributedSampler
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ from collections import defaultdict
11
+ from diffusers import UNet2DConditionModel, AutoencoderKL,AutoencoderKLFlux2
12
+ from accelerate import Accelerator
13
+ from datasets import load_from_disk
14
+ from tqdm import tqdm
15
+ from PIL import Image, ImageOps
16
+ import wandb
17
+ import random
18
+ import gc
19
+ from accelerate.state import DistributedType
20
+ from torch.distributed import broadcast_object_list
21
+ from torch.utils.checkpoint import checkpoint
22
+ from diffusers.models.attention_processor import AttnProcessor2_0
23
+ from datetime import datetime
24
+ import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
+ from collections import deque
27
+ from transformers import AutoTokenizer, AutoModel
28
+
29
+ # --------------------------- Parameters ---------------------------
30
+ ds_path = "/workspace/sdxs/datasets/mjnj"
31
+ project = "sdxs_flux"
32
+ batch_size = 32
33
+ base_learning_rate = 4e-5 #2.7e-5
34
+ min_learning_rate = 9e-6 #2.7e-5
35
+ num_epochs = 10
36
+ sample_interval_share = 10
37
+ cfg_dropout = 0.9
38
+ max_length = 192
39
+ use_wandb = True
40
+ use_comet_ml = False
41
+ save_model = True
42
+ use_decay = True
43
+ fbp = False
44
+ optimizer_type = "adam8bit"
45
+ torch_compile = False
46
+ unet_gradient = True
47
+ fixed_seed = False
48
+ shuffle = True
49
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
50
+ comet_ml_workspace = "recoilme"
51
+ torch.backends.cuda.matmul.allow_tf32 = True
52
+ torch.backends.cudnn.allow_tf32 = True
53
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
54
+ dtype = torch.float32
55
+ save_barrier = 1.01
56
+ warmup_percent = 0.01
57
+ percentile_clipping = 95 #96 #97
58
+ betta2 = 0.995
59
+ eps = 1e-7
60
+ clip_grad_norm = 1.0
61
+ limit = 0
62
+ checkpoints_folder = ""
63
+ mixed_precision = "no"
64
+ gradient_accumulation_steps = 1
65
+
66
+ accelerator = Accelerator(
67
+ mixed_precision=mixed_precision,
68
+ gradient_accumulation_steps=gradient_accumulation_steps
69
+ )
70
+ device = accelerator.device
71
+
72
+ # Diffusion parameters
73
+ n_diffusion_steps = 40
74
+ samples_to_generate = 12
75
+ guidance_scale = 4
76
+
77
+ # Folders for saving results
78
+ generated_folder = "samples"
79
+ os.makedirs(generated_folder, exist_ok=True)
80
+
81
+ # Seed setup
82
+ current_date = datetime.now()
83
+ seed = int(current_date.strftime("%Y%m%d"))
84
+ if fixed_seed:
85
+ torch.manual_seed(seed)
86
+ np.random.seed(seed)
87
+ random.seed(seed)
88
+ if torch.cuda.is_available():
89
+ torch.cuda.manual_seed_all(seed)
90
+
91
+ # --------------------------- LoRA parameters ---------------------------
92
+ lora_name = ""
93
+ lora_rank = 32
94
+ lora_alpha = 64
95
+
96
+ print("init")
97
+
98
+ loss_ratios = {
99
+ "mse": 1.5,
100
+ "mae": 0.5,
101
+ }
102
+ median_coeff_steps = 256
103
+
104
+ # Median-based loss normalization: we compute COEFFICIENTS
105
+ class MedianLossNormalizer:
106
+ def __init__(self, desired_ratios: dict, window_steps: int):
107
+ # normalize the ratios in case they do not sum to 1
108
+ #s = sum(desired_ratios.values())
109
+ #self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
110
+ #self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
111
+ #self.window = window_steps
112
+ self.ratios = {k: float(v) for k, v in desired_ratios.items()}
113
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
114
+ self.window = window_steps
115
+
116
+ def update_and_total(self, losses: dict):
117
+ """
118
+ losses: dict of key -> tensor (loss values)
+ Behavior:
+ - buffer ABS(l) only for the active (ratio > 0) losses
+ - coeff = ratio / median(abs(loss))
+ - total = sum(coeff * loss) over the active losses
+ CHANGED: we buffer abs() so the median stays positive and never breaks the division.
124
+ """
125
+ # buffer only the active losses
126
+ for k, v in losses.items():
127
+ if k in self.buffers and self.ratios.get(k, 0) > 0:
128
+ val = v.detach().abs().mean().cpu().item() # .item() is preferable to float() for tensors
129
+ self.buffers[k].append(val)
130
+ #self.buffers[k].append(float(v.detach().abs().cpu()))
131
+
132
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
133
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
134
+
135
+ # sum only over the active losses (ratio > 0)
136
+ total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
137
+ return total, coeffs, meds
138
+
139
+ # create the normalizer once loss_ratios is defined
140
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
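+
+ # Usage sketch (editor's addition, illustrative numbers only): with ratios
+ # {"mse": 1.5, "mae": 0.5} and rolling medians of, say, 0.30 for mse and 0.45
+ # for mae, update_and_total() yields coeffs 1.5/0.30 = 5.0 and 0.5/0.45 ≈ 1.11,
+ # so total = 5.0 * mse_loss + 1.11 * mae_loss: each term contributes in its
+ # desired ratio regardless of the raw loss scales.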
141
+
142
+ # --------------------------- WandB initialization ---------------------------
143
+ if accelerator.is_main_process:
144
+ if use_wandb:
145
+ wandb.init(project=project+lora_name, config={
146
+ "batch_size": batch_size,
147
+ "base_learning_rate": base_learning_rate,
148
+ "num_epochs": num_epochs,
149
+ "optimizer_type": optimizer_type,
150
+ })
151
+ if use_comet_ml:
152
+ from comet_ml import Experiment
153
+ comet_experiment = Experiment(
154
+ api_key=comet_ml_api_key,
155
+ project_name=project,
156
+ workspace=comet_ml_workspace
157
+ )
158
+ hyper_params = {
159
+ "batch_size": batch_size,
160
+ "base_learning_rate": base_learning_rate,
161
+ "num_epochs": num_epochs,
162
+ }
163
+ comet_experiment.log_parameters(hyper_params)
164
+
165
+ # Enable Flash Attention 2 / SDPA
166
+ torch.backends.cuda.enable_flash_sdp(True)
167
+
168
+ # --------------------------- Model loading ---------------------------
169
+ #vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
170
+ vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev",subfolder="vae",torch_dtype=dtype).to(device).eval()
171
+ tokenizer = AutoTokenizer.from_pretrained("tokenizer")
172
+ text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
173
+
174
+ # --- [UPDATED] Text encoding function (with mask and pooling) ---
175
+ def encode_texts(texts, max_length=max_length):
176
+ # If the texts are empty (for the unconditional branch), create placeholders
+ if texts is None:
+ # For None we would return zeros (the get_negative_embedding logic),
+ # but here we normally expect a list of strings.
+ pass
181
+
182
+ with torch.no_grad():
183
+ if isinstance(texts, str):
184
+ texts = [texts]
185
+
186
+ for i, prompt_item in enumerate(texts):
187
+ messages = [
188
+ {"role": "user", "content": prompt_item},
189
+ ]
190
+ prompt_item = tokenizer.apply_chat_template(
191
+ messages,
192
+ tokenize=False,
193
+ add_generation_prompt=True,
194
+ #enable_thinking=True,
195
+ )
196
+ #print(prompt_item+"\n")
197
+ texts[i] = prompt_item
198
+
199
+ toks = tokenizer(
200
+ texts,
201
+ return_tensors="pt",
202
+ padding="max_length",
203
+ truncation=True,
204
+ max_length=max_length
205
+ ).to(device)
206
+
207
+ outs = text_model(**toks, output_hidden_states=True, return_dict=True)
208
+
209
+ # Use last_hidden_state or hidden_states[-1]? (A common tip says last_hidden_state is better for Qwen; the author emphatically disagrees, hence hidden_states[-2].)
210
+ hidden = outs.hidden_states[-2]
211
+
212
+ # 2. Attention mask
213
+ attention_mask = toks["attention_mask"]
214
+
215
+ # 3. Pooled embedding (last non-padded token)
216
+ sequence_lengths = attention_mask.sum(dim=1) - 1
217
+ batch_size = hidden.shape[0]
218
+ pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
219
+
220
+ #return hidden, attention_mask
221
+ # --- NEW LOGIC: CONCATENATION FOR CROSS-ATTENTION ---
+ # 1. Expand the pooled vector into a sequence [B, 1, emb]
+ pooled_expanded = pooled.unsqueeze(1)
+
+ # 2. Concatenate the token sequence with the pooled vector
+ # !!! CHANGE HERE !!!: the pooled vector now goes FIRST
+ # Result: [B, 1 + L, emb]; pooling becomes the token at the START.
+ new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
+
+ # 3. Update the attention mask for the new token
+ # Attention mask: [B, 1 + L]; a 1 is prepended at the START.
+ # torch.ones((batch_size, 1), device=device) builds a [B, 1] mask of ones.
+ new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
234
+
235
+ return new_encoder_hidden_states, new_attention_mask
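+
+ # Shape check (editor's addition): for B prompts padded to max_length tokens,
+ # encode_texts returns embeddings of shape [B, 1 + max_length, hidden_dim] and
+ # a mask of shape [B, 1 + max_length], because the pooled vector is prepended
+ # as one extra "token" in front of the sequence.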
236
+
237
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
238
+ if shift_factor is None: shift_factor = 0.0
239
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
240
+ if scaling_factor is None: scaling_factor = 1.0
241
+
242
+ from diffusers import FlowMatchEulerDiscreteScheduler
243
+ num_train_timesteps = 1000
244
+ scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps)
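+
+ # Flow-matching note (editor's addition): with this scheduler the UNet is
+ # trained to predict a velocity v = noise - x0 along the straight path
+ # x_t = (1 - t) * x0 + t * noise; sampling integrates from t=1 (pure noise)
+ # back to t=0 (data) with Euler steps.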
245
+
246
+ class DistributedResolutionBatchSampler(Sampler):
247
+ def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
248
+ self.dataset = dataset
249
+ self.batch_size = max(1, batch_size // num_replicas)
250
+ self.num_replicas = num_replicas
251
+ self.rank = rank
252
+ self.shuffle = shuffle
253
+ self.drop_last = drop_last
254
+ self.epoch = 0
255
+
256
+ try:
257
+ widths = np.array(dataset["width"])
258
+ heights = np.array(dataset["height"])
259
+ except KeyError:
260
+ widths = np.zeros(len(dataset))
261
+ heights = np.zeros(len(dataset))
262
+
263
+ self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
264
+ self.size_groups = {}
265
+ for w, h in self.size_keys:
266
+ mask = (widths == w) & (heights == h)
267
+ self.size_groups[(w, h)] = np.where(mask)[0]
268
+
269
+ self.group_num_batches = {}
270
+ total_batches = 0
271
+ for size, indices in self.size_groups.items():
272
+ num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
273
+ self.group_num_batches[size] = num_full_batches
274
+ total_batches += num_full_batches
275
+
276
+ self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
277
+
278
+ def __iter__(self):
279
+ if torch.cuda.is_available():
280
+ torch.cuda.empty_cache()
281
+ all_batches = []
282
+ rng = np.random.RandomState(self.epoch)
283
+
284
+ for size, indices in self.size_groups.items():
285
+ indices = indices.copy()
286
+ if self.shuffle:
287
+ rng.shuffle(indices)
288
+ num_full_batches = self.group_num_batches[size]
289
+ if num_full_batches == 0:
290
+ continue
291
+ valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
292
+ batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
293
+ start_idx = self.rank * self.batch_size
294
+ end_idx = start_idx + self.batch_size
295
+ gpu_batches = batches[:, start_idx:end_idx]
296
+ all_batches.extend(gpu_batches)
297
+
298
+ if self.shuffle:
299
+ rng.shuffle(all_batches)
300
+ accelerator.wait_for_everyone()
301
+ return iter(all_batches)
302
+
303
+ def __len__(self):
304
+ return self.num_batches
305
+
306
+ def set_epoch(self, epoch):
307
+ self.epoch = epoch
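+
+ # Sampler note (editor's addition): samples are grouped by (width, height) so
+ # every batch is resolution-homogeneous; each rank then takes its own slice of
+ # each global batch, e.g. with batch_size=32 and 4 GPUs every rank gets 8
+ # same-sized samples per step, and all ranks iterate the same number of batches.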
308
+
309
+ # --- [UPDATED] Fixed-samples helper ---
310
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
311
+ size_groups = defaultdict(list)
312
+ try:
313
+ widths = dataset["width"]
314
+ heights = dataset["height"]
315
+ except KeyError:
316
+ widths = [0] * len(dataset)
317
+ heights = [0] * len(dataset)
318
+ for i, (w, h) in enumerate(zip(widths, heights)):
319
+ size = (w, h)
320
+ size_groups[size].append(i)
321
+
322
+ fixed_samples = {}
323
+ for size, indices in size_groups.items():
324
+ n_samples = min(samples_per_group, len(indices))
325
+ if len(size_groups)==1:
326
+ n_samples = samples_to_generate
327
+ if n_samples == 0:
328
+ continue
329
+ sample_indices = random.sample(indices, n_samples)
330
+ samples_data = [dataset[idx] for idx in sample_indices]
331
+
332
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
333
+ texts = [item["text"] for item in samples_data]
334
+
335
+ # Encode the texts on the fly to obtain masks and pooling
336
+ embeddings, masks = encode_texts(texts)
337
+
338
+ fixed_samples[size] = (latents, embeddings, masks, texts)
339
+
340
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
341
+ return fixed_samples
342
+
343
+ if limit > 0:
344
+ dataset = load_from_disk(ds_path).select(range(limit))
345
+ else:
346
+ dataset = load_from_disk(ds_path)
347
+
348
+ dataset = dataset.filter(
349
+ lambda x: [not (path.startswith("/workspace/ds/animesfw") or path.startswith("/workspace/ds/d4/animesfw")) for path in x["image_path"]],
350
+ batched=True,
351
+ batch_size=10000, # process 10k rows at a time
352
+ num_proc=8
353
+ )
354
+ print(f"Осталось примеров после фильтрации: {len(dataset)}")
355
+
356
+ # --- [UPDATED] Collate Function ---
357
+ def collate_fn_simple(batch):
358
+ # 1. Latents (VAE)
359
+ latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)
360
+
361
+ # 2. Take the raw text from the dataset
362
+ raw_texts = [item["text"] for item in batch]
363
+ texts = [
364
+ "" if t.lower().startswith("zero")
365
+ else "" if random.random() < cfg_dropout
366
+ else t[1:].lstrip() if t.startswith(".")
367
+ else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
368
+ for t in raw_texts
369
+ ]
370
+
371
+ # 3. Encode on the fly
+ # Returns: hidden (B, 1 + L, D), mask (B, 1 + L); the pooled token is prepended
373
+ embeddings, attention_mask = encode_texts(texts)
374
+
375
+ # the tokenizer's attention_mask is already in the right format, but cast it to long just in case
376
+ attention_mask = attention_mask.to(dtype=torch.int64)
377
+
378
+ return latents, embeddings, attention_mask
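+
+ # CFG-dropout note (editor's addition): with cfg_dropout = 0.9 set above, about
+ # 90% of captions in each batch are replaced by "" (unconditional training);
+ # lower the value to condition on text more often.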
379
+
380
+ batch_sampler = DistributedResolutionBatchSampler(
381
+ dataset=dataset,
382
+ batch_size=batch_size,
383
+ num_replicas=accelerator.num_processes,
384
+ rank=accelerator.process_index,
385
+ shuffle=shuffle
386
+ )
387
+
388
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
389
+ if accelerator.is_main_process:
390
+ print("Total samples", len(dataloader))
391
+ dataloader = accelerator.prepare(dataloader)
392
+
393
+ start_epoch = 0
394
+ global_step = 0
395
+ total_training_steps = (len(dataloader) * num_epochs)
396
+ world_size = accelerator.state.num_processes
397
+
398
+ # Load the UNet
399
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
400
+ if os.path.isdir(latest_checkpoint):
401
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
402
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
403
+ if unet_gradient:
404
+ unet.enable_gradient_checkpointing()
405
+ unet.set_use_memory_efficient_attention_xformers(False)
406
+ try:
407
+ unet.set_attn_processor(AttnProcessor2_0())
408
+ except Exception as e:
409
+ print(f"Ошибка при включении SDPA: {e}")
410
+ unet.set_use_memory_efficient_attention_xformers(True)
411
+ else:
412
+ raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")
413
+
414
+ if lora_name:
415
+ # ... (LoRA code unchanged and omitted for brevity; if you need it, restore the original block) ...
416
+ pass
417
+
418
+ # Optimizer
419
+ if lora_name:
420
+ trainable_params = [p for p in unet.parameters() if p.requires_grad]
421
+ else:
422
+ if fbp:
423
+ trainable_params = list(unet.parameters())
424
+
425
+ def create_optimizer(name, params):
426
+ if name == "adam8bit":
427
+ return bnb.optim.AdamW8bit(
428
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
429
+ percentile_clipping=percentile_clipping
430
+ )
431
+ elif name == "adam":
432
+ return torch.optim.AdamW(
433
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
434
+ )
435
+ elif name == "muon":
436
+ from muon import MuonWithAuxAdam
437
+ trainable_params = [p for p in params if p.requires_grad]
438
+ hidden_weights = [p for p in trainable_params if p.ndim >= 2]
439
+ hidden_gains_biases = [p for p in trainable_params if p.ndim < 2]
440
+
441
+ param_groups = [
442
+ dict(params=hidden_weights, use_muon=True,
443
+ lr=1e-3, weight_decay=1e-4),
444
+ dict(params=hidden_gains_biases, use_muon=False,
445
+ lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
446
+ ]
447
+ optimizer = MuonWithAuxAdam(param_groups)
448
+ from snooc import SnooC
449
+ return SnooC(optimizer)
450
+ else:
451
+ raise ValueError(f"Unknown optimizer: {name}")
452
+
453
+ if fbp:
454
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
455
+ def optimizer_hook(param):
456
+ optimizer_dict[param].step()
457
+ optimizer_dict[param].zero_grad(set_to_none=True)
458
+ for param in trainable_params:
459
+ param.register_post_accumulate_grad_hook(optimizer_hook)
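+ # fbp note (editor's addition): the post-accumulate hook lets every parameter
+ # step its own optimizer as soon as its gradient is ready, so full-model
+ # gradients never have to be held in memory at once.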
460
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
461
+ else:
462
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
463
+ def lr_schedule(step):
464
+ x = step / (total_training_steps * world_size)
465
+ warmup = warmup_percent
466
+ if not use_decay:
467
+ return base_learning_rate
468
+ if x < warmup:
469
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
470
+ decay_ratio = (x - warmup) / (1 - warmup)
471
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
472
+ (1 + math.cos(math.pi * decay_ratio))
473
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
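+ # Schedule sketch (editor's addition): with warmup_percent = 0.01, the first 1%
+ # of steps ramp linearly from min_learning_rate to base_learning_rate, then a
+ # cosine decays back to min_learning_rate; LambdaLR expects a multiplier of the
+ # base LR, hence the division by base_learning_rate.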
474
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
475
+
476
+ if torch_compile:
477
+ print("compiling")
478
+ unet = torch.compile(unet)
479
+ print("compiling - ok")
480
+
481
+ # Fixed samples
482
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
483
+
484
+ # --- [UPDATED] Negative-embedding helper (returns embedding and mask) ---
485
+ def get_negative_embedding(neg_prompt="", batch_size=1):
486
+ if not neg_prompt:
487
+ hidden_dim = 2048
488
+ seq_len = max_length + 1 # match encode_texts, which prepends a pooled token to the sequence
489
+ empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
490
+ empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
491
+ return empty_emb, empty_mask
492
+
493
+ uncond_emb, uncond_mask = encode_texts([neg_prompt])
494
+ uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
495
+ uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
496
+
497
+ return uncond_emb, uncond_mask
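+
+ # Note (editor's addition): the zero-tensor fallback above is assumed to mirror
+ # the [B, 1 + max_length, hidden_dim] layout produced by encode_texts, and
+ # hidden_dim = 2048 is assumed to match both the text encoder width and the
+ # UNet's cross_attention_dim.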
498
+
499
+ # Get the negative (empty) conditioning for validation
500
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
501
+
502
+ # --- Sample generation function ---
503
+ @torch.compiler.disable()
504
+ @torch.no_grad()
505
+ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
506
+ uncond_emb, uncond_mask = uncond_data
507
+
508
+ original_model = None
509
+ try:
510
+ if not torch_compile:
511
+ original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
512
+ else:
513
+ original_model = unet.eval()
514
+
515
+ vae.to(device=device).eval()
516
+
517
+ all_generated_images = []
518
+ all_captions = []
519
+
520
+ # Unpack the 4-tuple for each size group (masks were added)
521
+ for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
522
+ width, height = size
523
+ sample_latents = sample_latents.to(dtype=dtype, device=device)
524
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
525
+ sample_mask = sample_mask.to(device=device)
526
+
527
+ latents = torch.randn(
528
+ sample_latents.shape,
529
+ device=device,
530
+ dtype=sample_latents.dtype,
531
+ generator=torch.Generator(device=device).manual_seed(seed)
532
+ )
533
+
534
+ scheduler.set_timesteps(n_diffusion_steps, device=device)
535
+
536
+ for t in scheduler.timesteps:
537
+ if guidance_scale != 1:
538
+ latent_model_input = torch.cat([latents, latents], dim=0)
539
+
540
+ # Prepare the batches for CFG (Negative + Positive)
541
+ # 1. Embeddings
542
+ curr_batch_size = sample_text_embeddings.shape[0]
543
+ seq_len = sample_text_embeddings.shape[1]
544
+ hidden_dim = sample_text_embeddings.shape[2]
545
+
546
+ neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
547
+ text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)
548
+
549
+ # 2. Masks
550
+ neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
551
+ attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
552
+
553
+ else:
554
+ latent_model_input = latents
555
+ text_embeddings_batch = sample_text_embeddings
556
+ attention_mask_batch = sample_mask
557
+
558
+ # Prediction with all conditioning passed in
559
+ model_out = original_model(
560
+ latent_model_input,
561
+ t,
562
+ encoder_hidden_states=text_embeddings_batch,
563
+ encoder_attention_mask=attention_mask_batch,
564
+ )
565
+ flow = getattr(model_out, "sample", model_out)
566
+
567
+ if guidance_scale != 1:
568
+ flow_uncond, flow_cond = flow.chunk(2)
569
+ flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
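+
+ # CFG note (editor's addition): the batch is ordered [negative | positive]
+ # along dim 0, so chunk(2) splits the predictions in that order and the update
+ # extrapolates away from the negative prompt by guidance_scale.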
570
+
571
+ latents = scheduler.step(flow, t, latents).prev_sample
572
+
573
+ current_latents = latents
574
+
575
+ latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
576
+ decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
577
+ decoded_fp32 = decoded.to(torch.float32)
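+
+ # Decode note (editor's addition): encoding stores z = (x - shift) * scale, so
+ # the inverse z / scale + shift is applied before vae.decode.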
578
+
579
+ for img_idx, img_tensor in enumerate(decoded_fp32):
580
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
581
+ img = img.transpose(1, 2, 0)
582
+
583
+ if np.isnan(img).any():
584
+ print("NaNs found, saving stopped! Step:", step)
585
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
586
+
587
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
588
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
589
+ max_w_overall = max(255, max_w_overall)
590
+ max_h_overall = max(255, max_h_overall)
591
+
592
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
593
+ all_generated_images.append(padded_img)
594
+
595
+ caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
596
+ all_captions.append(caption_text)
597
+
598
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
599
+ pil_img.save(sample_path, "JPEG", quality=96)
600
+
601
+ if use_wandb and accelerator.is_main_process:
602
+ wandb_images = [
603
+ wandb.Image(img, caption=f"{all_captions[i]}")
604
+ for i, img in enumerate(all_generated_images)
605
+ ]
606
+ wandb.log({"generated_images": wandb_images})
607
+ if use_comet_ml and accelerator.is_main_process:
608
+ for i, img in enumerate(all_generated_images):
609
+ comet_experiment.log_image(
610
+ image_data=img,
611
+ name=f"step_{step}_img_{i}",
612
+ step=step,
613
+ metadata={"caption": all_captions[i]}
614
+ )
615
+ finally:
616
+ vae.to("cpu")
617
+ torch.cuda.empty_cache()
618
+ gc.collect()
619
+
620
+ # --------------------------- Sample generation before training ---------------------------
621
+ if accelerator.is_main_process:
622
+ if save_model:
623
+ print("Генерация сэмплов до старта обучения...")
624
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
625
+ accelerator.wait_for_everyone()
626
+
627
+ def save_checkpoint(unet, variant=""):
628
+ if accelerator.is_main_process:
629
+ if lora_name:
630
+ save_lora_checkpoint(unet)
631
+ else:
632
+ model_to_save = None
633
+ if not torch_compile:
634
+ model_to_save = accelerator.unwrap_model(unet)
635
+ else:
636
+ model_to_save = unet
637
+
638
+ if variant != "":
639
+ model_to_save.to(dtype=torch.float16).save_pretrained(
640
+ os.path.join(checkpoints_folder, f"{project}"), variant=variant
641
+ )
642
+ else:
643
+ model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
644
+
645
+ unet = unet.to(dtype=dtype)
646
+
647
+ # --------------------------- Training loop ---------------------------
648
+ if accelerator.is_main_process:
649
+ print(f"Total steps per GPU: {total_training_steps}")
650
+
651
+ epoch_loss_points = []
652
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
653
+
654
+ steps_per_epoch = len(dataloader)
655
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
656
+ min_loss = 4.
657
+
658
+ for epoch in range(start_epoch, start_epoch + num_epochs):
659
+ batch_losses = []
660
+ batch_grads = []
661
+ batch_sampler.set_epoch(epoch)
662
+ accelerator.wait_for_everyone()
663
+ unet.train()
664
+
665
+ for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
666
+ with accelerator.accumulate(unet):
667
+ if not save_model and epoch == 0 and step == 5:
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
+ print(f"Step {step}: {used_gb:.2f} GB")
670
+
671
+ # noise
+ noise = torch.randn_like(latents, dtype=latents.dtype)
+ # draw t from [0, 1]
+ #t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
+ u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
+ t = torch.sigmoid(torch.randn_like(u)) # logit-normal timestep sampling (u only supplies shape/dtype)
+
+ # interpolate between x0 and the noise
+ noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
+ # make integer timesteps for the UNet
+ timesteps = (t * scheduler.config.num_train_timesteps).long()
682
+
683
+ # --- UNet call with the mask ---
684
+ model_pred = unet(
685
+ noisy_latents,
686
+ timesteps,
687
+ encoder_hidden_states=embeddings,
688
+ encoder_attention_mask=attention_mask
689
+ ).sample
690
+
691
+ target = noise - latents
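+
+ # Flow-matching target (editor's addition): differentiating
+ # noisy_latents = (1 - t) * latents + t * noise with respect to t gives
+ # d(noisy_latents)/dt = noise - latents, which is exactly `target`; the UNet
+ # regresses this constant velocity field.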
692
+
693
+ mse_loss = F.mse_loss(model_pred.float(), target.float())
694
+ mae_loss = F.l1_loss(model_pred.float(), target.float())
695
+ batch_losses.append(mse_loss.detach().item())
696
+
697
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
698
+ accelerator.wait_for_everyone()
699
+
700
+ losses_dict = {}
701
+ losses_dict["mse"] = mse_loss
702
+ losses_dict["mae"] = mae_loss
703
+
704
+ # === Normalize all losses ===
705
+ abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
706
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
707
+
708
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
709
+ accelerator.wait_for_everyone()
710
+
711
+ accelerator.backward(total_loss)
712
+
713
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
714
+ accelerator.wait_for_everyone()
715
+
716
+ grad = 0.0
717
+ if not fbp:
718
+ if accelerator.sync_gradients:
719
+ #with torch.amp.autocast('cuda', enabled=False):
720
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
721
+ if grad_val is not None:
+ grad = float(grad_val)
+ else:
+ print("grad_val is None")
+ grad = 0.0
726
+ optimizer.step()
727
+ lr_scheduler.step()
728
+ optimizer.zero_grad(set_to_none=True)
729
+
730
+ if accelerator.sync_gradients:
731
+ global_step += 1
732
+ progress_bar.update(1)
733
+ if accelerator.is_main_process:
734
+ if fbp:
735
+ current_lr = base_learning_rate
736
+ else:
737
+ current_lr = lr_scheduler.get_last_lr()[0]
738
+ batch_grads.append(grad)
739
+
740
+ log_data = {}
741
+ log_data["loss_mse"] = mse_loss.detach().item()
742
+ log_data["loss_mae"] = mae_loss.detach().item()
743
+ log_data["lr"] = current_lr
744
+ log_data["grad"] = grad
745
+ log_data["loss_norm"] = float(total_loss.item())
746
+ for k, c in coeffs.items():
747
+ log_data[f"coeff_{k}"] = float(c)
748
+ if accelerator.sync_gradients:
749
+ if use_wandb:
750
+ wandb.log(log_data, step=global_step)
751
+ if use_comet_ml:
752
+ comet_experiment.log_metrics(log_data, step=global_step)
753
+
754
+ if global_step % sample_interval == 0:
755
+ # Pass the (emb, mask) tuple as the negative
756
+ if save_model:
757
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
758
+ elif epoch % 10 == 0:
759
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
760
+ last_n = sample_interval
761
+
762
+ if save_model:
763
+ has_losses = len(batch_losses) > 0
764
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
765
+ last_loss = batch_losses[-1] if has_losses else 0.0
766
+ max_loss = max(avg_sample_loss, last_loss)
767
+ should_save = max_loss < min_loss * save_barrier
768
+ print(
769
+ f"Saving: {should_save} | Max: {max_loss:.4f} | "
770
+ f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
771
+ )
772
+ # 6. Save and update
773
+ if should_save:
774
+ min_loss = max_loss
775
+ save_checkpoint(unet)
776
+
777
+ if accelerator.is_main_process:
778
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
779
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
780
+
781
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
782
+ log_data_ep = {
783
+ "epoch_loss": avg_epoch_loss,
784
+ "epoch_grad": avg_epoch_grad,
785
+ "epoch": epoch + 1,
786
+ }
787
+ if use_wandb:
788
+ wandb.log(log_data_ep)
789
+ if use_comet_ml:
790
+ comet_experiment.log_metrics(log_data_ep)
791
+
792
+ if accelerator.is_main_process:
793
+ print("Обучение завершено! Сохраняем финальную модель...")
794
+ #if save_model:
795
+ save_checkpoint(unet,"fp16")
796
+ if use_comet_ml:
797
+ comet_experiment.end()
798
+ accelerator.free_memory()
799
+ if torch.distributed.is_initialized():
800
+ torch.distributed.destroy_process_group()
801
+
802
+ print("Готово!")