recoilme commited on
Commit
2016da2
·
1 Parent(s): 0623613
dataset_flux.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install flash-attn --no-build-isolation
2
+ import torch
3
+ import os
4
+ import gc
5
+ import numpy as np
6
+ import random
7
+ import json
8
+ import shutil
9
+ import time
10
+
11
+ from datasets import Dataset, load_from_disk, concatenate_datasets
12
+ from diffusers import AutoencoderKL,AutoencoderKLWan,AutoencoderKLFlux2
13
+ from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
14
+ from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
15
+ from typing import Dict, List, Tuple, Optional, Any
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+ from datetime import timedelta
19
+
20
+ # ---------------- 1️⃣ Настройки ----------------
21
+ dtype = torch.float32
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ batch_size = 5
24
+ min_size = 192 #384 #320 #192 #256 #192
25
+ max_size = 320 #768 #640 #384 #256 #384
26
+ step = 64 #64
27
+ empty_share = 0.0
28
+ limit = 0
29
+ # Основная процедура обработки
30
+ folder_path = "/workspace/mjnj" #alchemist"
31
+ save_path = "/workspace/sdxs/datasets/mjnj" #"alchemist"
32
+ os.makedirs(save_path, exist_ok=True)
33
+
34
+ # Функция для очистки CUDA памяти
35
def clear_cuda_memory():
    """Report peak CUDA allocation and release cached GPU memory.

    No-op when CUDA is unavailable. NOTE: ``max_memory_allocated`` is the
    peak since process start (or last reset), not the current usage.
    """
    if not torch.cuda.is_available():
        return
    peak_gb = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f"used_gb: {peak_gb:.2f} GB")
    torch.cuda.empty_cache()
    gc.collect()
41
+
42
+ # ---------------- 2️⃣ Загрузка моделей ----------------
43
def load_models():
    """Instantiate the FLUX.2 VAE in eval mode on the configured device.

    Text-encoder loading is currently disabled upstream; only the VAE is
    returned.
    """
    print("Загрузка моделей...")
    vae_model = AutoencoderKLFlux2.from_pretrained(
        "black-forest-labs/FLUX.2-dev",
        subfolder="vae",
        torch_dtype=dtype,
    )
    return vae_model.to(device).eval()
58
+
59
+ #vae, model, tokenizer = load_models()
60
+ vae = load_models()
61
+
62
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
63
+ if shift_factor is None:
64
+ shift_factor = 0.0
65
+
66
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
67
+ if scaling_factor is None:
68
+ scaling_factor = 1.0
69
+
70
+ latents_mean = getattr(vae.config, "latents_mean", None)
71
+ latents_std = getattr(vae.config, "latents_std", None)
72
+
73
+ # ---------------- 3️⃣ Трансформации ----------------
74
def get_image_transform(min_size=256, max_size=512, step=64):
    """Build a closure that resizes, crops and normalizes a PIL image.

    The crop window is snapped to a multiple of *step* and clamped to
    [min_size, max_size]. With ``dry_run=True`` only the final crop
    dimensions are returned (no pixel work is done).
    """

    def _fit(w, h, target, long_side):
        # Scale (w, h) so that the longer (long_side=True) or shorter
        # side equals *target*, truncating the other side to int.
        anchor_is_width = (w >= h) if long_side else (w <= h)
        if anchor_is_width:
            return target, int(target * h / w)
        return int(target * w / h), target

    def transform(img, dry_run=False):
        src_w, src_h = img.size

        # Scale so the longest side equals max_size.
        new_w, new_h = _fit(src_w, src_h, max_size, long_side=True)
        # If that pushed a side below min_size, rescale so the shortest
        # side equals min_size instead.
        if min(new_w, new_h) < min_size:
            new_w, new_h = _fit(src_w, src_h, min_size, long_side=False)

        # Snap the crop window to the step grid, clamped to the bounds.
        crop_w = max(min_size, min(max_size, (new_w // step) * step))
        crop_h = max(min_size, min(max_size, (new_h // step) * step))

        if dry_run:
            # Caller only needs the final dimensions.
            return crop_w, crop_h

        resized = img.convert("RGB").resize((new_w, new_h), Image.LANCZOS)

        # Crop a third of the vertical slack from the top — skips the
        # watermark strip typically found near the top edge.
        top = (new_h - crop_h) // 3
        cropped = resized.crop((0, top, crop_w, top + crop_h))
        out_w, out_h = cropped.size

        # To tensor, then map [0, 1] -> [-1, 1].
        tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(
            ToTensor()(cropped)
        )
        return tensor, cropped, out_w, out_h

    return transform
127
+
128
+ # ---------------- 4️⃣ Функции обработки ----------------
129
def last_token_pool(last_hidden_states: torch.Tensor,
                    attention_mask: torch.Tensor) -> torch.Tensor:
    """Select one embedding per sequence: the final attended token.

    With left padding every sequence ends at the last position, so the
    last column is taken directly; otherwise the index of the last
    non-padded token is derived from the attention mask.
    """
    batch = attention_mask.shape[0]
    # Left padding <=> every sequence's final position is attended.
    if attention_mask[:, -1].sum() == batch:
        return last_hidden_states[:, -1]
    last_idx = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.shape[0],
                        device=last_hidden_states.device)
    return last_hidden_states[rows, last_idx]
139
+
140
def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150, normalize=False):
    """Tokenize *texts* and return per-token hidden states from the LM.

    Args:
        texts: list of prompt strings.
        tokenizer: HF tokenizer matching *model*.
        model: a CausalLM wrapper; its inner base model (``model.model``)
            is invoked so hidden states, not logits, are produced.
        device: device the tokenized batch is moved to.
        max_length: fixed padded sequence length.
        normalize: L2-normalize each token embedding (CLIP-style).

    Returns:
        np.ndarray of shape [B, L, D] — the last hidden layer on CPU.
    """
    # FIX: ``F`` was used below but never imported at module level, so
    # normalize=True raised NameError. Import locally to keep the module
    # namespace unchanged.
    import torch.nn.functional as F

    with torch.inference_mode():
        batch = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length
        ).to(device)

        # Run the base transformer (inside the CausalLM) to get hidden states.
        outputs = model.model(**batch, output_hidden_states=True)

        # Last layer: embeddings of all tokens, [B, L, D].
        hidden_states = outputs.hidden_states[-1]

        # Optional per-token L2 normalization (as in CLIP).
        if normalize:
            hidden_states = F.normalize(hidden_states, p=2, dim=-1)

        return hidden_states.cpu().numpy()
172
+
173
def clean_label(label):
    """Strip boilerplate VLM caption prefixes and image-number tags."""
    # Remove every occurrence of each phrase, in the original order.
    for junk in ("Image 1", "Image 2", "Image 3", "Image 4",
                 "The image depicts ", "The image presents ",
                 "The image features ", "The image portrays ",
                 "The image is "):
        label = label.replace(junk, "")
    label = label.strip()
    # Drop a leading period left over from a removed prefix.
    if label.startswith("."):
        label = label[1:].lstrip()
    return label
178
+
179
def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
    """Randomly drop captions for classifier-free guidance training.

    For each label, with probability *prob_to_make_empty* the model sees
    an empty string while the logging copy gets a "zero: " prefix;
    otherwise both copies keep the original text.

    Returns:
        (labels_for_model, labels_for_logging) — two parallel lists.
    """
    model_labels = []
    log_labels = []

    for text in original_labels:
        dropped = random.random() < prob_to_make_empty
        model_labels.append("" if dropped else text)
        log_labels.append(f"zero: {text}" if dropped else text)

    return model_labels, log_labels
202
+
203
def encode_to_latents(images, texts):
    """Encode a batch of PIL images to VAE latents, paired with captions.

    Uses the module-level size settings and VAE; all images in the batch
    are expected to share one final size (caller groups them).

    Returns:
        dict with "vae" latents (np.ndarray), cleaned "text" labels and
        final "width"/"height" lists — or None if no image survived the
        transform.
    """
    transform = get_image_transform(min_size, max_size, step)

    try:
        tensors, pils = [], []
        widths, heights = [], []

        # Transform each image; failures are logged and skipped.
        for img in images:
            try:
                tens, pil, w, h = transform(img)
            except Exception as e:
                print(f"Ошибка трансформации: {e}")
                continue
            tensors.append(tens)
            pils.append(pil)
            widths.append(w)
            heights.append(h)

        if not tensors:
            return None

        batch = torch.stack(tensors).to(device, dtype)
        if batch.ndim == 5:
            batch = batch.unsqueeze(2)  # [B, C, 1, H, W]

        # Deterministic latents (distribution mode), then flux-style
        # shift/scale normalization.
        with torch.no_grad():
            mode = vae.encode(batch).latent_dist.mode()
            latents = (mode - shift_factor) / scaling_factor

        latents_np = latents.to(dtype).cpu().numpy()

        # Clean captions and apply CFG caption dropout; the logging
        # variant (with "zero:" markers) is what gets stored.
        cleaned = [clean_label(t) for t in texts]
        _, cleaned = process_labels_for_guidance(cleaned, empty_share)

        return {
            "vae": latents_np,
            "text": cleaned,
            "width": widths,
            "height": heights
        }

    except Exception as e:
        print(f"Критическая ошибка в encode_to_latents: {e}")
        raise
256
+
257
+
258
+ # ---------------- 5️⃣ Обработка папки с изображениями и текстами ----------------
259
def process_folder(folder_path, limit=None):
    """Recursively collect image/caption path pairs under *folder_path*.

    WARNING: unreadable images (and their .txt captions) are DELETED
    from disk, matching the original pipeline behavior.

    Returns:
        Parallel lists: image paths, caption paths, crop widths, crop
        heights. Stops early once *limit* (> 0) pairs are collected.
    """
    image_paths, text_paths = [], []
    width, height = [], []
    transform = get_image_transform(min_size, max_size, step)

    for root, _dirs, files in os.walk(folder_path):
        for filename in files:
            if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            image_path = os.path.join(root, filename)
            try:
                img = Image.open(image_path)
            except Exception as e:
                # Corrupt image: remove it together with its caption.
                print(f"Ошибка при открытии {image_path}: {e}")
                os.remove(image_path)
                text_path = os.path.splitext(image_path)[0] + ".txt"
                if os.path.exists(text_path):
                    os.remove(text_path)
                continue

            # Dry run: compute post-transform size only, no decoding.
            w, h = transform(img, dry_run=True)
            text_path = os.path.splitext(image_path)[0] + ".txt"

            # Keep only pairs whose caption exists and size is valid.
            if os.path.exists(text_path) and min(w, h) > 0:
                image_paths.append(image_path)
                text_paths.append(text_path)
                width.append(w)
                height.append(h)

                if limit and limit > 0 and len(image_paths) >= limit:
                    print(f"Достигнут лимит в {limit} изображений")
                    return image_paths, text_paths, width, height

    print(f"Найдено {len(image_paths)} изображений с текстовыми описаниями")
    return image_paths, text_paths, width, height
304
+
305
def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
    """Encode the dataset in chunks, grouping images by final size.

    Each chunk is split into same-size groups so the VAE can run with a
    real batch size; every group is saved as its own dataset under
    "<save_path>_temp/chunk_<i>_size_<w>x<h>".

    Args:
        image_paths/text_paths: parallel lists of file paths.
        width/height: per-image crop sizes (lists) or scalars applied to all.
        chunk_size: number of images per chunk.
        batch_size: map() batch size within a same-size group.
    """
    total_files = len(image_paths)
    start_time = time.time()
    # FIX: running counter replaces the original
    # `list(size_groups.values()).index(group_data)` ETA arithmetic,
    # which rescanned the dict (O(n^2)) and relied on dict equality.
    processed = 0

    for chunk_idx, start in enumerate(range(0, total_files, chunk_size), 1):
        end = min(start + chunk_size, total_files)
        chunk_image_paths = image_paths[start:end]
        chunk_text_paths = text_paths[start:end]
        chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
        chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)

        # Read captions; unreadable files degrade to empty strings.
        chunk_texts = []
        for text_path in chunk_text_paths:
            try:
                with open(text_path, 'r', encoding='utf-8') as f:
                    chunk_texts.append(f.read().strip())
            except Exception as e:
                print(f"Ошибка чтения {text_path}: {e}")
                chunk_texts.append("")

        # Group images by (width, height) so each VAE batch is uniform.
        size_groups = {}
        for i in range(len(chunk_image_paths)):
            size_key = (chunk_widths[i], chunk_heights[i])
            group = size_groups.setdefault(size_key, {"image_paths": [], "texts": []})
            group["image_paths"].append(chunk_image_paths[i])
            group["texts"].append(chunk_texts[i])

        # Encode and persist each same-size group separately.
        for size_key, group_data in size_groups.items():
            print(f"Обработка группы с размером {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} изображений")

            group_dataset = Dataset.from_dict({
                "image_path": group_data["image_paths"],
                "text": group_data["texts"]
            })

            # batch_size > 1 is safe here: all images share one size.
            processed_group = group_dataset.map(
                lambda examples: encode_to_latents(
                    [Image.open(path) for path in examples["image_path"]],
                    examples["text"]
                ),
                batched=True,
                batch_size=batch_size,
                desc=f"Обработка группы размера {size_key[0]}x{size_key[1]}"
            )

            group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
            processed_group.save_to_disk(group_save_path)
            clear_cuda_memory()

            # ETA report based on images actually completed.
            processed += len(group_data["image_paths"])
            elapsed = time.time() - start_time
            if processed > 0:
                remaining = (elapsed / processed) * (total_files - processed)
                elapsed_str = str(timedelta(seconds=int(elapsed)))
                remaining_str = str(timedelta(seconds=int(remaining)))
                print(f"ETA: Прошло {elapsed_str}, Осталось {remaining_str}, Прогресс {processed}/{total_files} ({processed/total_files:.1%})")
369
+
370
+ # ---------------- 7️⃣ Объединение чанков ----------------
371
def combine_chunks(temp_path, final_path):
    """Merge every "chunk_*" dataset in *temp_path* into one dataset."""
    chunk_dirs = sorted(
        os.path.join(temp_path, name)
        for name in os.listdir(temp_path)
        if name.startswith("chunk_")
    )

    merged = concatenate_datasets([load_from_disk(d) for d in chunk_dirs])
    merged.save_to_disk(final_path)

    print(f"✅ Датасет успешно сохранен в: {final_path}")
384
+
385
+
386
+
387
# Temporary directory holding per-chunk datasets before the final merge.
temp_path = f"{save_path}_temp"
os.makedirs(temp_path, exist_ok=True)

# Collect image/caption pairs (with pre-computed crop sizes).
image_paths, text_paths, width, height = process_folder(folder_path, limit)
print(f"Всего найдено {len(image_paths)} изображений")

# Encode everything chunk by chunk.
process_in_chunks(image_paths, text_paths, width, height, chunk_size=20000, batch_size=batch_size)

# The raw source folder is no longer needed once latents are written.
try:
    shutil.rmtree(folder_path)
    print(f"✅ Папка {folder_path} успешно удалена")
except Exception as e:
    print(f"⚠️ Ошибка при удалении папки: {e}")

# Merge all chunk datasets into the final dataset.
combine_chunks(temp_path, save_path)

# Drop the temporary chunk storage.
try:
    shutil.rmtree(temp_path)
    print(f"✅ Временная папка {temp_path} успешно удалена")
except Exception as e:
    print(f"⚠️ Ошибка при удалении временной папки: {e}")
samples/sdxs_1b_384x768_0.jpg ADDED

Git LFS Details

  • SHA256: 04e973e35af92734bfb31877d2a5a3a5d548246b06bb6a2f26803635b333b9e7
  • Pointer size: 130 Bytes
  • Size of remote file: 75.4 kB
samples/sdxs_1b_416x768_0.jpg ADDED

Git LFS Details

  • SHA256: ef3a15fb13c198be3a2f1359e993998f8bf8dcb85f716dea457cc859eb9751c5
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB
samples/sdxs_1b_448x768_0.jpg ADDED

Git LFS Details

  • SHA256: 3051f142812e47c678363a93dafff38b75845714389790441e50a4ecafe639f5
  • Pointer size: 130 Bytes
  • Size of remote file: 93.2 kB
samples/sdxs_1b_480x768_0.jpg ADDED

Git LFS Details

  • SHA256: 1f68b43fd1e1a684a8e8a5b34713393081c3097d923db99fa99236a93cea3512
  • Pointer size: 131 Bytes
  • Size of remote file: 301 kB
samples/sdxs_1b_512x768_0.jpg ADDED

Git LFS Details

  • SHA256: fb71463171685787a5a3d1cc600dd46ca221eee9d56cc549f2247e936e000457
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
samples/sdxs_1b_544x768_0.jpg ADDED

Git LFS Details

  • SHA256: 9bc771776872e1b17c5ca4d6cdf267b5f1e5de1d3481d66d11fc98d45527cc6c
  • Pointer size: 131 Bytes
  • Size of remote file: 185 kB
samples/sdxs_1b_576x768_0.jpg ADDED

Git LFS Details

  • SHA256: 68aa07d63bdd9437b0da26a1ed99ab710daca55627494196fa8697520455a18d
  • Pointer size: 131 Bytes
  • Size of remote file: 121 kB
samples/sdxs_1b_608x768_0.jpg ADDED

Git LFS Details

  • SHA256: c6b4c65b62eb6aba01e8cb585d45cb87d6ea8f3a5c99578cd9a61d8f8ac75707
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
samples/sdxs_1b_640x768_0.jpg ADDED

Git LFS Details

  • SHA256: 874c5035246f3b58839bba13f1b25076a5b9d1222bc8f4362d4f9411c56a6589
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
samples/sdxs_1b_672x768_0.jpg ADDED

Git LFS Details

  • SHA256: 792f7f42a44d290b45d44a07cf6b68d642fb24eca5ec1632612efc6cfb81a1b9
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB
samples/sdxs_1b_704x768_0.jpg ADDED

Git LFS Details

  • SHA256: c020f7189bc050a687fff7f500019adcccaf331e5300dd3dcbe282547428997c
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB
samples/sdxs_1b_736x768_0.jpg ADDED

Git LFS Details

  • SHA256: 7c7337c4ee1c596b0c7bfdb886eb44e29f5e919b8f314ed51173c6a5bc982b0c
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
samples/sdxs_1b_768x384_0.jpg ADDED

Git LFS Details

  • SHA256: 622ce7eef549cf3a412353120eb0386546bd0583f623e2a462b1845c48099fa1
  • Pointer size: 131 Bytes
  • Size of remote file: 252 kB
samples/sdxs_1b_768x416_0.jpg ADDED

Git LFS Details

  • SHA256: 2105f3d6b6a10c0ef4cf3097cf60fb205890c5647500700de58ca9156a9622df
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
samples/sdxs_1b_768x448_0.jpg ADDED

Git LFS Details

  • SHA256: 8e63e99418ceceecb8ecba636c8c4d44bf6c230439055f583b0634119156ab3c
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
samples/sdxs_1b_768x480_0.jpg ADDED

Git LFS Details

  • SHA256: cbb3fb683b10bbb48039b15371092fc84300e3f8705ca0ae6a93bf406a8da880
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB
samples/sdxs_1b_768x512_0.jpg ADDED

Git LFS Details

  • SHA256: 2e1bcd280987b762c5adbd6eb420e4a7a6db56837d8a429f0c255b362186ef14
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
samples/sdxs_1b_768x544_0.jpg ADDED

Git LFS Details

  • SHA256: 24e992127aa14803ae3b82d78d6fede7f0023c03edc5ec701c73a3c4926b25b2
  • Pointer size: 131 Bytes
  • Size of remote file: 152 kB
samples/sdxs_1b_768x576_0.jpg ADDED

Git LFS Details

  • SHA256: 693106a26aff4d77744c59e7fbd5fcb3edc0f0d30072fdcd2c9a036a9b0a2f50
  • Pointer size: 130 Bytes
  • Size of remote file: 93.7 kB
samples/sdxs_1b_768x608_0.jpg ADDED

Git LFS Details

  • SHA256: db6fe5c7b64ef2528a838c7f26b5da8549e87bd486146a3b9d0ade0c052d1e81
  • Pointer size: 131 Bytes
  • Size of remote file: 240 kB
samples/sdxs_1b_768x640_0.jpg ADDED

Git LFS Details

  • SHA256: b8bf0081508aff64350ca55dbffce57ccf3db295b65de30799e82699bd0a8cbf
  • Pointer size: 131 Bytes
  • Size of remote file: 156 kB
samples/sdxs_1b_768x672_0.jpg ADDED

Git LFS Details

  • SHA256: a0198ef31c35bc409b3da00518249ec7b324abfd814d51d6cc6c771c5f1d5fb9
  • Pointer size: 131 Bytes
  • Size of remote file: 231 kB
samples/sdxs_1b_768x704_0.jpg ADDED

Git LFS Details

  • SHA256: a5d9033fbdadcf6fb88e933c8a3283e74bc54155a4a05a52546495f05c259aca
  • Pointer size: 131 Bytes
  • Size of remote file: 252 kB
samples/sdxs_1b_768x736_0.jpg ADDED

Git LFS Details

  • SHA256: 4cb94b9587b8dcc3b65bf9649eb1f2bd34cfa237c35048c583001285e24912e6
  • Pointer size: 131 Bytes
  • Size of remote file: 195 kB
samples/sdxs_1b_768x768_0.jpg ADDED

Git LFS Details

  • SHA256: 2c11f2d0e23e86c88ff6c6ed158de56f2c46302b56f8f64d9de70af001091f4f
  • Pointer size: 131 Bytes
  • Size of remote file: 222 kB
sdxs_1b/config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5c470e960d12887252492e7b4144aac436dfac59018713f1a59e21e385f00b32
3
- size 1860
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7d112c25b36970128b4ecb7566d23cce42308edcc37652950ae27444d433eaa
3
+ size 1863
sdxs_1b/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f6fbd31dda1a4d43b642aa62feeff17eb12c8b2b4b106e159370c9e5156db983
3
  size 4463672488
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:104b893ba36e1359f8468450653c0e00a7b2faeb1a42d3797747f0d0c06ccd0a
3
  size 4463672488
sdxs_flux/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
sdxs_flux/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:226251d5e756f1eef156cd84e3271a56b181ecf30cc9c86d1bc9b59777da6b6d
3
+ size 1861
sdxs_flux/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41dfe79753fadad9ba8e148ac73e1b61f23c7a1146cece5147c0041d613298bd
3
+ size 3195253456
src/sdxs_create_flux.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5906961734bae7d2b94597c9e23a452729b33709b5a90281097c4471872a597
3
+ size 34156
train.py CHANGED
@@ -28,19 +28,19 @@ from transformers import AutoTokenizer, AutoModel
28
 
29
  # --------------------------- Параметры ---------------------------
30
  ds_path = "/workspace/sdxs/datasets/768"
31
- project = "sdxs_07b"
32
- batch_size = 120
33
  base_learning_rate = 4e-5 #2.7e-5
34
  min_learning_rate = 9e-6 #2.7e-5
35
- num_epochs = 100
36
- sample_interval_share = 1
37
  max_length = 192
38
  use_wandb = True
39
  use_comet_ml = False
40
  save_model = True
41
  use_decay = True
42
- fbp = False
43
- optimizer_type = "adam"
44
  torch_compile = False
45
  unet_gradient = True
46
  fixed_seed = False
@@ -95,8 +95,8 @@ lora_alpha = 64
95
  print("init")
96
 
97
  loss_ratios = {
98
- "mse": 0.8,
99
- "mae": 0.2,
100
  }
101
  median_coeff_steps = 256
102
 
@@ -733,7 +733,8 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
733
  log_data["loss_mse"] = mse_loss.detach().item()
734
  log_data["loss_mae"] = mae_loss.detach().item()
735
  log_data["lr"] = current_lr
736
- log_data["grad"] = grad
 
737
  log_data["loss_norm"] = float(total_loss.item())
738
  for k, c in coeffs.items():
739
  log_data[f"coeff_{k}"] = float(c)
 
28
 
29
  # --------------------------- Параметры ---------------------------
30
  ds_path = "/workspace/sdxs/datasets/768"
31
+ project = "sdxs_1b"
32
+ batch_size = 32
33
  base_learning_rate = 4e-5 #2.7e-5
34
  min_learning_rate = 9e-6 #2.7e-5
35
+ num_epochs = 10
36
+ sample_interval_share = 20
37
  max_length = 192
38
  use_wandb = True
39
  use_comet_ml = False
40
  save_model = True
41
  use_decay = True
42
+ fbp = True
43
+ optimizer_type = "adam8bit"
44
  torch_compile = False
45
  unet_gradient = True
46
  fixed_seed = False
 
95
  print("init")
96
 
97
  loss_ratios = {
98
+ "mse": 0.6,
99
+ "mae": 0.4,
100
  }
101
  median_coeff_steps = 256
102
 
 
733
  log_data["loss_mse"] = mse_loss.detach().item()
734
  log_data["loss_mae"] = mae_loss.detach().item()
735
  log_data["lr"] = current_lr
736
+ if not fbp:
737
+ log_data["grad"] = grad
738
  log_data["loss_norm"] = float(total_loss.item())
739
  for k, c in coeffs.items():
740
  log_data[f"coeff_{k}"] = float(c)
train_flux.py ADDED
@@ -0,0 +1,802 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from comet_ml import Experiment
2
+ import os
3
+ import math
4
+ import torch
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from torch.utils.data import DataLoader, Sampler
8
+ from torch.utils.data.distributed import DistributedSampler
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ from collections import defaultdict
11
+ from diffusers import UNet2DConditionModel, AutoencoderKL,AutoencoderKLFlux2
12
+ from accelerate import Accelerator
13
+ from datasets import load_from_disk
14
+ from tqdm import tqdm
15
+ from PIL import Image, ImageOps
16
+ import wandb
17
+ import random
18
+ import gc
19
+ from accelerate.state import DistributedType
20
+ from torch.distributed import broadcast_object_list
21
+ from torch.utils.checkpoint import checkpoint
22
+ from diffusers.models.attention_processor import AttnProcessor2_0
23
+ from datetime import datetime
24
+ import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
+ from collections import deque
27
+ from transformers import AutoTokenizer, AutoModel
28
+
29
+ # --------------------------- Параметры ---------------------------
30
+ ds_path = "/workspace/sdxs/datasets/mjnj"
31
+ project = "sdxs_flux"
32
+ batch_size = 32
33
+ base_learning_rate = 4e-5 #2.7e-5
34
+ min_learning_rate = 9e-6 #2.7e-5
35
+ num_epochs = 10
36
+ sample_interval_share = 10
37
+ cfg_dropout = 0.9
38
+ max_length = 192
39
+ use_wandb = True
40
+ use_comet_ml = False
41
+ save_model = True
42
+ use_decay = True
43
+ fbp = False
44
+ optimizer_type = "adam8bit"
45
+ torch_compile = False
46
+ unet_gradient = True
47
+ fixed_seed = False
48
+ shuffle = True
49
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
50
+ comet_ml_workspace = "recoilme"
51
+ torch.backends.cuda.matmul.allow_tf32 = True
52
+ torch.backends.cudnn.allow_tf32 = True
53
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
54
+ dtype = torch.float32
55
+ save_barrier = 1.01
56
+ warmup_percent = 0.01
57
+ percentile_clipping = 95 #96 #97
58
+ betta2 = 0.995
59
+ eps = 1e-7
60
+ clip_grad_norm = 1.0
61
+ limit = 0
62
+ checkpoints_folder = ""
63
+ mixed_precision = "no"
64
+ gradient_accumulation_steps = 1
65
+
66
+ accelerator = Accelerator(
67
+ mixed_precision=mixed_precision,
68
+ gradient_accumulation_steps=gradient_accumulation_steps
69
+ )
70
+ device = accelerator.device
71
+
72
+ # Параметры для диффузии
73
+ n_diffusion_steps = 40
74
+ samples_to_generate = 12
75
+ guidance_scale = 4
76
+
77
+ # Папки для сохранения результатов
78
+ generated_folder = "samples"
79
+ os.makedirs(generated_folder, exist_ok=True)
80
+
81
+ # Настройка seed
82
+ current_date = datetime.now()
83
+ seed = int(current_date.strftime("%Y%m%d"))
84
+ if fixed_seed:
85
+ torch.manual_seed(seed)
86
+ np.random.seed(seed)
87
+ random.seed(seed)
88
+ if torch.cuda.is_available():
89
+ torch.cuda.manual_seed_all(seed)
90
+
91
+ # --------------------------- Параметры LoRA ---------------------------
92
+ lora_name = ""
93
+ lora_rank = 32
94
+ lora_alpha = 64
95
+
96
+ print("init")
97
+
98
+ loss_ratios = {
99
+ "mse": 1.5,
100
+ "mae": 0.5,
101
+ }
102
+ median_coeff_steps = 256
103
+
104
+ # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ
105
class MedianLossNormalizer:
    """Scales loss terms so their medians follow the desired ratios.

    A rolling window of |loss| values is kept per term; each active
    term's coefficient is ratio / median(|loss|), so terms contribute
    in the requested proportions regardless of raw magnitude.
    """

    def __init__(self, desired_ratios: dict, window_steps: int):
        # Ratios are taken verbatim (NOT re-normalized to sum to 1).
        self.ratios = {name: float(r) for name, r in desired_ratios.items()}
        self.buffers = {name: deque(maxlen=window_steps) for name in self.ratios}
        self.window = window_steps

    def update_and_total(self, losses: dict):
        """Buffer |losses|, derive coefficients, return the weighted sum.

        Only terms with ratio > 0 are buffered and summed. Absolute
        values are buffered so the median stays positive and the
        division below cannot flip sign or blow up.

        Args:
            losses: mapping name -> loss tensor.

        Returns:
            (total, coeffs, meds): weighted total loss, per-term
            coefficients, and current medians.
        """
        # Buffer only the active terms.
        for name, value in losses.items():
            if name in self.buffers and self.ratios.get(name, 0) > 0:
                self.buffers[name].append(value.detach().abs().mean().cpu().item())

        meds = {name: (np.median(buf) if len(buf) > 0 else 1.0)
                for name, buf in self.buffers.items()}
        coeffs = {name: self.ratios[name] / max(meds[name], 1e-12)
                  for name in self.ratios}

        # Sum only over active (ratio > 0) terms.
        total = sum(coeffs[name] * losses[name]
                    for name in coeffs if self.ratios.get(name, 0) > 0)
        return total, coeffs, meds
138
+
139
# Create the normalizer once loss_ratios is defined
normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)

# --------------------------- Experiment tracking (wandb / comet) ---------------------------
if accelerator.is_main_process:
    if use_wandb:
        wandb.init(project=project+lora_name, config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        })
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace
        )
        hyper_params = {
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
        }
        comet_experiment.log_parameters(hyper_params)

# Enable Flash Attention 2 / SDPA kernels
torch.backends.cuda.enable_flash_sdp(True)

# --------------------------- Model loading ---------------------------
# FLUX.2 VAE for latent decode; tokenizer + text encoder from local folders
vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev",subfolder="vae",torch_dtype=dtype).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("tokenizer")
text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
173
+
174
# --- Text encoding (chat template + penultimate hidden states + last-token pooling) ---
def encode_texts(texts, max_length=max_length):
    """Encode prompts with the chat-template tokenizer and the text encoder.

    Args:
        texts: a prompt string or a list of prompt strings.
        max_length: padding/truncation budget (defaults to the module-level
            ``max_length``).

    Returns:
        (hidden_states, attention_mask):
            hidden_states — (B, 1 + L, D): the pooled last-token embedding is
            prepended as an extra token in FRONT of the per-token states
            (taken from ``hidden_states[-2]`` of the encoder, on purpose).
            attention_mask — (B, 1 + L); the prepended pooled token is always
            attended to.
    """
    with torch.no_grad():
        if isinstance(texts, str):
            texts = [texts]

        # FIX: build a NEW list instead of writing back into the caller's
        # list — the original `texts[i] = prompt_item` mutated the caller's
        # prompts in place (e.g. the raw captions kept around for logging).
        templated = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt_item}],
                tokenize=False,
                add_generation_prompt=True,
                #enable_thinking=True,
            )
            for prompt_item in texts
        ]

        toks = tokenizer(
            templated,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length
        ).to(device)

        outs = text_model(**toks, output_hidden_states=True, return_dict=True)

        # Penultimate layer is used deliberately (not last_hidden_state)
        hidden = outs.hidden_states[-2]
        attention_mask = toks["attention_mask"]

        # Pooled embedding = hidden state of the last non-padded token
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = hidden.shape[0]
        pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]

        # Prepend the pooled vector as the FIRST token and extend the mask;
        # the extra mask column keeps the tokenizer mask's integer dtype so
        # both halves of the concat are dtype-consistent.
        pooled_expanded = pooled.unsqueeze(1)
        new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
        new_attention_mask = torch.cat(
            [torch.ones((batch_size, 1), device=device, dtype=attention_mask.dtype), attention_mask],
            dim=1,
        )

        return new_encoder_hidden_states, new_attention_mask
236
+
237
# VAE latent normalization constants (fall back to identity when unset)
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None: shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None: scaling_factor = 1.0

# Flow-matching (rectified flow) scheduler, used for validation sampling;
# its num_train_timesteps also scales the integer timesteps during training
from diffusers import FlowMatchEulerDiscreteScheduler
num_train_timesteps = 1000
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps)
245
+
246
class DistributedResolutionBatchSampler(Sampler):
    """Batch sampler that buckets images by (width, height) so every batch is
    single-resolution, and shards whole batches across DDP replicas.

    Doc-only: the shuffle -> reshape -> column-slice order is what keeps
    ranks consistent — every rank draws the identical permutation from
    np.random.RandomState(epoch) and takes only its own column slice.
    """

    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
        self.dataset = dataset
        # per-rank batch size; `batch_size` argument is the global batch
        self.batch_size = max(1, batch_size // num_replicas)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        try:
            widths = np.array(dataset["width"])
            heights = np.array(dataset["height"])
        except KeyError:
            # datasets without size columns collapse into one (0, 0) bucket
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        # row indices for each distinct (w, h) pair
        self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
        self.size_groups = {}
        for w, h in self.size_keys:
            mask = (widths == w) & (heights == h)
            self.size_groups[(w, h)] = np.where(mask)[0]

        # only full (batch_size * num_replicas)-sized batches survive (drop_last)
        self.group_num_batches = {}
        total_batches = 0
        for size, indices in self.size_groups.items():
            num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
            self.group_num_batches[size] = num_full_batches
            total_batches += num_full_batches

        # rounded down so every replica reports the same length
        self.num_batches = (total_batches // self.num_replicas) * self.num_replicas

    def __iter__(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        all_batches = []
        # seeded by epoch -> identical permutation sequence on every rank
        rng = np.random.RandomState(self.epoch)

        for size, indices in self.size_groups.items():
            indices = indices.copy()
            if self.shuffle:
                rng.shuffle(indices)
            num_full_batches = self.group_num_batches[size]
            if num_full_batches == 0:
                continue
            valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
            # rows = global batches; each rank slices its own column range
            batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
            start_idx = self.rank * self.batch_size
            end_idx = start_idx + self.batch_size
            gpu_batches = batches[:, start_idx:end_idx]
            all_batches.extend(gpu_batches)

        if self.shuffle:
            rng.shuffle(all_batches)
        # NOTE(review): relies on the module-level `accelerator`; aligns all
        # ranks before iteration starts.
        accelerator.wait_for_everyone()
        return iter(all_batches)

    def __len__(self):
        # NOTE(review): may differ slightly from the count actually yielded
        # by __iter__ when total batches is not a multiple of num_replicas.
        return self.num_batches

    def set_epoch(self, epoch):
        # called once per epoch so each epoch gets a fresh deterministic shuffle
        self.epoch = epoch
308
+
309
# --- Fixed validation samples, grouped by (width, height) ---
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick a few fixed dataset rows per resolution for periodic sampling.

    Args:
        dataset: HF dataset with "vae", "text" and optionally "width"/"height".
        samples_per_group: rows to take per resolution bucket.

    Returns:
        {(w, h): (latents, embeddings, masks, texts)} with latents moved to
        `device` and texts encoded through encode_texts().
    """
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        # no size columns -> everything lands in one (0, 0) bucket
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, (w, h) in enumerate(zip(widths, heights)):
        size_groups[(w, h)].append(i)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(samples_per_group, len(indices))
        if len(size_groups) == 1:
            # Single-resolution dataset: use the full sample budget, but clamp
            # to the group size. FIX: the original assigned samples_to_generate
            # unchecked, and random.sample() raises ValueError when the sample
            # size exceeds the population.
            n_samples = min(samples_to_generate, len(indices))
        if n_samples == 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]

        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        texts = [item["text"] for item in samples_data]

        # Encode on the fly so masks/pooled tokens match the training path
        embeddings, masks = encode_texts(texts)

        fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
    return fixed_samples
342
+
343
# Load the preprocessed dataset (optionally truncated for debugging)
if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)

# Drop the animesfw subsets by their source image path
dataset = dataset.filter(
    lambda x: [not (path.startswith("/workspace/ds/animesfw") or path.startswith("/workspace/ds/d4/animesfw")) for path in x["image_path"]],
    batched=True,
    batch_size=10000,  # process 10k rows per call
    num_proc=8
)
print(f"Осталось примеров после фильтрации: {len(dataset)}")
355
+
356
# --- Collate: stack precomputed VAE latents and encode captions on the fly ---
def collate_fn_simple(batch):
    """Build one training batch.

    Returns (latents, embeddings, attention_mask); captions are cleaned or
    randomly blanked (classifier-free guidance dropout) before encoding.
    """
    # 1. Precomputed VAE latents -> one tensor on the training device
    latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)

    # 2. Caption policy, one caption at a time. random() is drawn only for
    #    captions that do not start with "zero", matching the original's
    #    short-circuiting ternary chain.
    def _prepare(caption):
        if caption.lower().startswith("zero"):
            return ""                       # explicit "zero" caption -> unconditional
        if random.random() < cfg_dropout:
            return ""                       # CFG dropout -> unconditional
        if caption.startswith("."):
            return caption[1:].lstrip()     # leading "." keeps the caption verbatim
        return (caption.replace("The image shows ", "")
                       .replace("The image is ", "")
                       .replace("This image captures ", "")
                       .strip())

    raw_texts = [item["text"] for item in batch]
    texts = [_prepare(t) for t in raw_texts]

    # 3. Encode; mask cast to int64 for the UNet's encoder_attention_mask
    embeddings, attention_mask = encode_texts(texts)
    attention_mask = attention_mask.to(dtype=torch.int64)

    return latents, embeddings, attention_mask
379
+
380
# Resolution-bucketed, DDP-aware sampler + dataloader
batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle
)

dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
if accelerator.is_main_process:
    print("Total samples", len(dataloader))
dataloader = accelerator.prepare(dataloader)

start_epoch = 0
global_step = 0
# per-GPU optimizer steps over the full run
total_training_steps = (len(dataloader) * num_epochs)
world_size = accelerator.state.num_processes
398
# Load the UNet from the project checkpoint folder (must already exist)
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Загружаем UNet из чекпоинта:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
    if unet_gradient:
        unet.enable_gradient_checkpointing()
    # Prefer PyTorch SDPA; fall back to xformers if setting the processor fails
    unet.set_use_memory_efficient_attention_xformers(False)
    try:
        unet.set_attn_processor(AttnProcessor2_0())
    except Exception as e:
        print(f"Ошибка при включении SDPA: {e}")
        unet.set_use_memory_efficient_attention_xformers(True)
else:
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")

if lora_name:
    # LoRA setup intentionally omitted here (lora_name is empty in this config)
    pass
417
+
418
# --------------------------- Optimizer ---------------------------
if lora_name:
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
else:
    if fbp:
        # fused-backward mode: one optimizer per parameter (see below)
        trainable_params = list(unet.parameters())

def create_optimizer(name, params):
    """Build the optimizer selected by `name` ("adam8bit" | "adam" | "muon")."""
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
            percentile_clipping=percentile_clipping
        )
    elif name == "adam":
        return torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
        )
    elif name == "muon":
        from muon import MuonWithAuxAdam
        # Muon for matrices (ndim >= 2), auxiliary AdamW for gains/biases
        trainable_params = [p for p in params if p.requires_grad]
        hidden_weights = [p for p in trainable_params if p.ndim >= 2]
        hidden_gains_biases = [p for p in trainable_params if p.ndim < 2]

        param_groups = [
            dict(params=hidden_weights, use_muon=True,
                 lr=1e-3, weight_decay=1e-4),
            dict(params=hidden_gains_biases, use_muon=False,
                 lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
        ]
        optimizer = MuonWithAuxAdam(param_groups)
        from snooc import SnooC
        return SnooC(optimizer)
    else:
        raise ValueError(f"Unknown optimizer: {name}")

if fbp:
    # Fused backward pass: step each parameter's optimizer the moment its
    # gradient is accumulated, then free the gradient immediately.
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)
    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)
    unet, optimizer = accelerator.prepare(unet, optimizer_dict)
else:
    optimizer = create_optimizer(optimizer_type, unet.parameters())
    def lr_schedule(step):
        # Linear warmup to base LR, then cosine decay to min_learning_rate.
        # `step` counts optimizer steps on this rank, hence the * world_size.
        x = step / (total_training_steps * world_size)
        warmup = warmup_percent
        if not use_decay:
            return base_learning_rate
        if x < warmup:
            return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
        decay_ratio = (x - warmup) / (1 - warmup)
        return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
               (1 + math.cos(math.pi * decay_ratio))
    # LambdaLR multiplies the base LR, so the absolute schedule is divided by it
    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

if torch_compile:
    print("compiling")
    unet = torch.compile(unet)
    print("compiling - ok")
480
+
481
# Fixed validation samples, one bucket per resolution
fixed_samples = get_fixed_samples_by_resolution(dataset)
483
+
484
# --- Negative (unconditional) conditioning for classifier-free guidance ---
def get_negative_embedding(neg_prompt="", batch_size=1):
    """Return (embeddings, mask) for the negative prompt used during CFG.

    An empty prompt yields a synthesized zero embedding; otherwise the prompt
    is encoded via encode_texts() and repeated batch_size times.
    """
    if not neg_prompt:
        # NOTE(review): 2048 is assumed to equal the text encoder's hidden
        # size — confirm against text_model.config.hidden_size.
        hidden_dim = 2048
        # encode_texts() prepends the pooled vector as an extra token, so its
        # sequences are max_length + 1 long. FIX: the fallback must match that
        # shape or the CFG torch.cat() during sampling fails; the original
        # used max_length and was off by one.
        seq_len = max_length + 1
        empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
        return empty_emb, empty_mask

    uncond_emb, uncond_mask = encode_texts([neg_prompt])
    uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)

    return uncond_emb, uncond_mask
498
+
499
# Negative (uncond) conditioning used for CFG during validation sampling
uncond_emb, uncond_mask = get_negative_embedding("low quality")
501
+
502
# --- Validation sample generation (CFG + FlowMatch Euler sampling) ---
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    """Generate images for the fixed samples, save JPEGs, log to wandb/comet.

    Args:
        fixed_samples_cpu: {(w, h): (latents, embeddings, mask, texts)}.
        uncond_data: (uncond_emb, uncond_mask) negative conditioning for CFG.
        step: global step — used only in logging names.
    """
    uncond_emb, uncond_mask = uncond_data

    original_model = None
    try:
        # Inference on the unwrapped model unless it was torch.compile'd
        if not torch_compile:
            original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            original_model = unet.eval()

        # VAE is parked on CPU between validations; bring it in for decoding
        vae.to(device=device).eval()

        all_generated_images = []
        all_captions = []

        # One denoising run per resolution bucket
        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            sample_latents = sample_latents.to(dtype=dtype, device=device)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            sample_mask = sample_mask.to(device=device)

            # Fresh noise with the fixed per-run seed -> comparable samples
            latents = torch.randn(
                sample_latents.shape,
                device=device,
                dtype=sample_latents.dtype,
                generator=torch.Generator(device=device).manual_seed(seed)
            )

            scheduler.set_timesteps(n_diffusion_steps, device=device)

            for t in scheduler.timesteps:
                if guidance_scale != 1:
                    # CFG: negative + positive conditioning in one doubled batch
                    latent_model_input = torch.cat([latents, latents], dim=0)

                    curr_batch_size = sample_text_embeddings.shape[0]
                    seq_len = sample_text_embeddings.shape[1]
                    hidden_dim = sample_text_embeddings.shape[2]

                    # 1. Embeddings: negative first, then positive
                    neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                    text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)

                    # 2. Masks, same order
                    neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
                    attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)

                else:
                    latent_model_input = latents
                    text_embeddings_batch = sample_text_embeddings
                    attention_mask_batch = sample_mask

                model_out = original_model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings_batch,
                    encoder_attention_mask=attention_mask_batch,
                )
                # output object exposes .sample; raw tensors pass through
                flow = getattr(model_out, "sample", model_out)

                if guidance_scale != 1:
                    # first half = unconditional, second half = conditional
                    flow_uncond, flow_cond = flow.chunk(2)
                    flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

                latents = scheduler.step(flow, t, latents).prev_sample

            current_latents = latents

            # Undo training-time latent normalization before decoding
            latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
            decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
            decoded_fp32 = decoded.to(torch.float32)

            for img_idx, img_tensor in enumerate(decoded_fp32):
                # [-1, 1] -> [0, 1], CHW -> HWC
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)

                # NOTE(review): despite the message, this only logs —
                # saving continues with the NaN image.
                if np.isnan(img).any():
                    print("NaNs found, saving stopped! Step:", step)
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                # Pad to the largest bucket so the logged grid aligns
                # (loop-invariant; recomputed per image as in the original)
                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)

                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                # The unpadded image is what gets written to disk
                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=96)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]}
                )
    finally:
        # Always park the VAE back on CPU and release cached memory
        vae.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
619
+
620
# --------------------------- Samples before training starts ---------------------------
if accelerator.is_main_process:
    if save_model:
        print("Генерация сэмплов до старта обучения...")
        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
accelerator.wait_for_everyone()
626
+
627
+ def save_checkpoint(unet, variant=""):
628
+ if accelerator.is_main_process:
629
+ if lora_name:
630
+ save_lora_checkpoint(unet)
631
+ else:
632
+ model_to_save = None
633
+ if not torch_compile:
634
+ model_to_save = accelerator.unwrap_model(unet)
635
+ else:
636
+ model_to_save = unet
637
+
638
+ if variant != "":
639
+ model_to_save.to(dtype=torch.float16).save_pretrained(
640
+ os.path.join(checkpoints_folder, f"{project}"), variant=variant
641
+ )
642
+ else:
643
+ model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
644
+
645
+ unet = unet.to(dtype=dtype)
646
+
647
# --------------------------- Training loop ---------------------------
if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")

epoch_loss_points = []
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")

steps_per_epoch = len(dataloader)
# sample/checkpoint roughly sample_interval_share times per epoch
sample_interval = max(1, steps_per_epoch // sample_interval_share)
min_loss = 4.

for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    batch_sampler.set_epoch(epoch)  # new deterministic shuffle each epoch
    accelerator.wait_for_everyone()
    unet.train()

    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        with accelerator.accumulate(unet):
            # one-off VRAM report early in a dry run (no checkpointing mode)
            if save_model == False and epoch == 0 and step == 5 :
                used_gb = torch.cuda.max_memory_allocated() / 1024**3
                print(f"Шаг {step}: {used_gb:.2f} GB")

            # Rectified-flow pair: Gaussian noise and a logit-normal t in (0, 1)
            noise = torch.randn_like(latents, dtype=latents.dtype)
            u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
            t = torch.sigmoid(torch.randn_like(u))

            # Linear interpolation between clean latents (t=0) and noise (t=1)
            noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
            # integer timesteps for the UNet's timestep embedding
            timesteps = (t * scheduler.config.num_train_timesteps).long()

            # --- UNet forward with the encoder attention mask ---
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=embeddings,
                encoder_attention_mask=attention_mask
            ).sample

            # Flow-matching target: velocity from data toward noise
            target = noise - latents

            mse_loss = F.mse_loss(model_pred.float(), target.float())
            mae_loss = F.l1_loss(model_pred.float(), target.float())
            batch_losses.append(mse_loss.detach().item())

            # periodic rank synchronization around logging/sampling steps
            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            losses_dict = {}
            losses_dict["mse"] = mse_loss
            losses_dict["mae"] = mae_loss

            # === Median-normalized weighted total ===
            abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
            total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            accelerator.backward(total_loss)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            grad = 0.0
            if not fbp:
                if accelerator.sync_gradients:
                    grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
                    if grad_val is not None:
                        grad = float(grad_val)
                    else:
                        # FIX: this message used to sit on the NOT-None branch,
                        # printing "grad_val is None" on every sync step.
                        print("grad_val is None")
                        grad = 0.0
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
                if accelerator.is_main_process:
                    if fbp:
                        # fbp mode has no LR scheduler; report the base LR
                        current_lr = base_learning_rate
                    else:
                        current_lr = lr_scheduler.get_last_lr()[0]
                    batch_grads.append(grad)

                    log_data = {}
                    log_data["loss_mse"] = mse_loss.detach().item()
                    log_data["loss_mae"] = mae_loss.detach().item()
                    log_data["lr"] = current_lr
                    log_data["grad"] = grad
                    log_data["loss_norm"] = float(total_loss.item())
                    for k, c in coeffs.items():
                        log_data[f"coeff_{k}"] = float(c)
                    if accelerator.sync_gradients:
                        if use_wandb:
                            wandb.log(log_data, step=global_step)
                        if use_comet_ml:
                            comet_experiment.log_metrics(log_data, step=global_step)

                    if global_step % sample_interval == 0:
                        # negative conditioning is passed as an (emb, mask) tuple
                        if save_model:
                            generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                        elif epoch % 10 == 0:
                            generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                        last_n = sample_interval

                        if save_model:
                            # checkpoint only while the loss keeps (roughly) improving
                            has_losses = len(batch_losses) > 0
                            avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
                            last_loss = batch_losses[-1] if has_losses else 0.0
                            max_loss = max(avg_sample_loss, last_loss)
                            should_save = max_loss < min_loss * save_barrier
                            print(
                                f"Saving: {should_save} | Max: {max_loss:.4f} | "
                                f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
                            )
                            if should_save:
                                min_loss = max_loss
                                save_checkpoint(unet)

    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
        avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0

        print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
        log_data_ep = {
            "epoch_loss": avg_epoch_loss,
            "epoch_grad": avg_epoch_grad,
            "epoch": epoch + 1,
        }
        if use_wandb:
            wandb.log(log_data_ep)
        if use_comet_ml:
            comet_experiment.log_metrics(log_data_ep)

if accelerator.is_main_process:
    print("Обучение завершено! Сохраняем финальную модель...")
    save_checkpoint(unet,"fp16")
    if use_comet_ml:
        comet_experiment.end()
accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()

print("Готово!")