recoilme committed on
Commit
1ae8d8f
1 Parent(s): 1acb679
Files changed (7)
  1. dataset_new.py +376 -0
  2. media/result_grid.jpg +2 -2
  3. pipeline_sdxs.py +8 -2
  4. src/dataset_mjnj.ipynb +2 -2
  5. test.ipynb +2 -2
  6. train-Copy1.py +771 -0
  7. train.py +10 -46
dataset_new.py ADDED
@@ -0,0 +1,376 @@
1
+ # pip install flash-attn --no-build-isolation
2
+ import torch
3
+ import os
4
+ import gc
5
+ import numpy as np
6
+ import random
7
+ import json
8
+ import shutil
9
+ import time
10
+
11
+ from datasets import Dataset, load_from_disk, concatenate_datasets
12
+ from diffusers import AutoencoderKL,AutoencoderKLWan
13
+ from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
14
+ from transformers import AutoModel, AutoImageProcessor, AutoTokenizer, AutoModelForCausalLM
15
+ from typing import Dict, List, Tuple, Optional, Any
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+ from datetime import timedelta
19
+
20
+ # ---------------- 1️⃣ Settings ----------------
21
+ dtype = torch.float16
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ batch_size = 10
24
+ min_size = 320 #192 #256 #192
25
+ max_size = 640 #384 #256 #384
26
+ step = 64 #64
27
+ empty_share = 0.05
28
+ limit = 0
29
+ # Main processing pipeline
30
+ folder_path = "/workspace/ds/dataset/mjnj" #alchemist"
31
+ save_path = "/workspace/mjnj_640" #"alchemist"
32
+ os.makedirs(save_path, exist_ok=True)
33
+
34
+ # Helper for freeing CUDA memory
35
+ def clear_cuda_memory():
36
+ if torch.cuda.is_available():
37
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
38
+ print(f"used_gb: {used_gb:.2f} GB")
39
+ torch.cuda.empty_cache()
40
+ gc.collect()
41
+
42
+ # ---------------- 2️⃣ Model loading ----------------
43
+ def load_models():
44
+ print("Загрузка моделей...")
45
+ vae = AutoencoderKL.from_pretrained("AiArtLab/sdxs3d",subfolder="vae",torch_dtype=dtype).to(device).eval()
46
+
47
+ model_name = "Qwen/Qwen3-0.6B"
48
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
49
+ model = AutoModelForCausalLM.from_pretrained(
50
+ model_name,
51
+ torch_dtype=dtype,
52
+ device_map=device
53
+ ).eval()
54
+ return vae, model, tokenizer
55
+
56
+ vae, model, tokenizer = load_models()
57
+
58
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
59
+ if shift_factor is None:
60
+ shift_factor = 0.0
61
+
62
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
63
+ if scaling_factor is None:
64
+ scaling_factor = 1.0
65
+
66
+ latents_mean = getattr(vae.config, "latents_mean", None)
67
+ latents_std = getattr(vae.config, "latents_std", None)
68
+
69
+ # ---------------- 3️⃣ Transforms ----------------
70
+ def get_image_transform(min_size=256, max_size=512, step=64):
71
+ def transform(img, dry_run=False):
72
+ # Remember the original image dimensions
73
+ original_width, original_height = img.size
74
+
75
+ # 0. Resize so that the longer side equals max_size
76
+ if original_width >= original_height:
77
+ new_width = max_size
78
+ new_height = int(max_size * original_height / original_width)
79
+ else:
80
+ new_height = max_size
81
+ new_width = int(max_size * original_width / original_height)
82
+
83
+ if new_height < min_size or new_width < min_size:
84
+ # 1. Resize so that the shorter side equals min_size
85
+ if original_width <= original_height:
86
+ new_width = min_size
87
+ new_height = int(min_size * original_height / original_width)
88
+ else:
89
+ new_height = min_size
90
+ new_width = int(min_size * original_width / original_height)
91
+
92
+ # 2. If either side exceeds max_size, prepare to crop (snapped to the step grid)
93
+ crop_width = min(max_size, (new_width // step) * step)
94
+ crop_height = min(max_size, (new_height // step) * step)
95
+
96
+ # Ensure the crop dimensions are at least min_size
97
+ crop_width = max(min_size, crop_width)
98
+ crop_height = max(min_size, crop_height)
99
+
100
+ # If only a dry-run size calculation was requested
101
+ if dry_run:
102
+ return crop_width, crop_height
103
+
104
+ # Convert to RGB and resize
105
+ img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
106
+
107
+ # Crop coordinates (offset a third from the top to avoid watermarks)
108
+ top = (new_height - crop_height) // 3
109
+ left = 0
110
+
111
+ # Crop the image
112
+ img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
113
+
114
+ # Record the final dimensions after all transforms
115
+ final_width, final_height = img_cropped.size
116
+
117
+ # Convert to a normalized tensor
118
+ img_tensor = ToTensor()(img_cropped)
119
+ img_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img_tensor)
120
+ return img_tensor, img_cropped, final_width, final_height
121
+
122
+ return transform
123
+
124
+ # ---------------- 4️⃣ Processing functions ----------------
125
+ def encode_texts_batch(texts, tokenizer, model, device="cuda", max_length=150):
126
+ with torch.inference_mode():
127
+ # 1. Tokenization
128
+ batch = tokenizer(
129
+ texts,
130
+ return_tensors="pt",
131
+ padding="max_length",
132
+ truncation=True,
133
+ max_length=max_length,
134
+ ).to(device)
135
+
136
+ # 2. Forward pass
137
+ outputs = model.model(**batch, output_hidden_states=True, return_dict=True)
138
+ hidden = outputs.hidden_states[-1]
139
+ mask = batch["attention_mask"].unsqueeze(-1) # (B, L, 1)
140
+
141
+ # 3. Zero out embeddings at pad token positions
142
+ hidden = hidden * mask # (B, L, D)
143
+
144
+ return hidden.cpu().numpy()
145
+
146
+
147
+ def clean_label(label):
148
+ label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "")
149
+ return label
150
+
151
+ def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
152
+ """
153
+ Processes a list of labels for classifier-free guidance.
154
+
155
+ With probability prob_to_make_empty:
156
+ - the label in the first list is replaced with an empty string;
157
+ - the label in the second list gets a "zero:" prefix.
158
+
159
+ Otherwise both lists keep the original label.
160
+
161
+ """
162
+ labels_for_model = []
163
+ labels_for_logging = []
164
+
165
+ for label in original_labels:
166
+ if random.random() < prob_to_make_empty:
167
+ labels_for_model.append("") # Заменяем на пустую строку для модели
168
+ labels_for_logging.append(f"zero: {label}") # Добавляем префикс для логгирования
169
+ else:
170
+ labels_for_model.append(label) # Оставляем оригинальную метку для модели
171
+ labels_for_logging.append(label) # Оставляем оригинальную метку для логгирования
172
+
173
+ return labels_for_model, labels_for_logging
174
+
175
+ def encode_to_latents(images, texts):
176
+ transform = get_image_transform(min_size, max_size, step)
177
+
178
+ try:
179
+ # Process the images (all of the same size)
180
+ transformed_tensors = []
181
+ pil_images = []
182
+ widths, heights = [], []
183
+
184
+ # Apply the transform to every image
185
+ for img in images:
186
+ try:
187
+ t_img, pil_img, w, h = transform(img)
188
+ transformed_tensors.append(t_img)
189
+ pil_images.append(pil_img)
190
+ widths.append(w)
191
+ heights.append(h)
192
+ except Exception as e:
193
+ print(f"Ошибка трансформации: {e}")
194
+ continue
195
+
196
+ if not transformed_tensors:
197
+ return None
198
+
199
+ # Build the batch
200
+ batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
201
+ if batch_tensor.ndim == 4 and isinstance(vae, AutoencoderKLWan):
202
+ batch_tensor = batch_tensor.unsqueeze(2) # video VAE expects a frame dim -> [B, C, 1, H, W]
203
+
204
+ # Encode the batch
205
+ with torch.no_grad():
206
+ posteriors = vae.encode(batch_tensor).latent_dist.mode()
207
+ latents = (posteriors - shift_factor) / scaling_factor
208
+
209
+ latents_np = latents.to(dtype).cpu().numpy()
210
+
211
+ # Process the texts
212
+ text_labels = [clean_label(text) for text in texts]
213
+
214
+ model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
215
+ embeddings = encode_texts_batch(model_prompts, tokenizer, model)
216
+
217
+ return {
218
+ "vae": latents_np,
219
+ "embeddings": embeddings,
220
+ "text": text_labels,
221
+ "width": widths,
222
+ "height": heights
223
+ }
224
+
225
+ except Exception as e:
226
+ print(f"Критическая ошибка в encode_to_latents: {e}")
227
+ raise
228
+
229
+
230
+ # ---------------- 5️⃣ Processing a folder of images and texts ----------------
231
+ def process_folder(folder_path, limit=None):
232
+ """
233
+ Recursively walks the given directory and all nested directories,
234
+ collecting paths to images and their matching text files.
235
+ """
236
+ image_paths = []
237
+ text_paths = []
238
+ width = []
239
+ height = []
240
+ transform = get_image_transform(min_size, max_size, step)
241
+
242
+ # Use os.walk to traverse directories recursively
243
+ for root, dirs, files in os.walk(folder_path):
244
+ for filename in files:
245
+ # Check whether the file is an image
246
+ if filename.lower().endswith((".jpg", ".jpeg", ".png")):
247
+ image_path = os.path.join(root, filename)
248
+ try:
249
+ img = Image.open(image_path)
250
+ except Exception as e:
251
+ print(f"Ошибка при открытии {image_path}: {e}")
252
+ os.remove(image_path)
253
+ text_path = os.path.splitext(image_path)[0] + ".txt"
254
+ if os.path.exists(text_path):
255
+ os.remove(text_path)
256
+ continue
257
+ # Run the transform in dry-run mode just to get the target size
258
+ w, h = transform(img, dry_run=True)
259
+ # Build the path to the matching text file
260
+ text_path = os.path.splitext(image_path)[0] + ".txt"
261
+
262
+ # Keep the pair only if the text file exists
263
+ if os.path.exists(text_path) and min(w, h)>0:
264
+ image_paths.append(image_path)
265
+ text_paths.append(text_path)
266
+ width.append(w) # collect sizes
267
+ height.append(h) # collect sizes
268
+
269
+ # Honor the count limit
270
+ if limit and limit>0 and len(image_paths) >= limit:
271
+ print(f"Достигнут лимит в {limit} изображений")
272
+ return image_paths, text_paths, width, height
273
+
274
+ print(f"Найдено {len(image_paths)} изображений с текстовыми описаниями")
275
+ return image_paths, text_paths, width, height
276
+
277
+ def process_in_chunks(image_paths, text_paths, width, height, chunk_size=10000, batch_size=1):
278
+ total_files = len(image_paths)
279
+ start_time = time.time()
280
+ chunks = range(0, total_files, chunk_size)
281
+
282
+ for chunk_idx, start in enumerate(chunks, 1):
283
+ end = min(start + chunk_size, total_files)
284
+ chunk_image_paths = image_paths[start:end]
285
+ chunk_text_paths = text_paths[start:end]
286
+ chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
287
+ chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)
288
+
289
+ # Read the caption texts
290
+ chunk_texts = []
291
+ for text_path in chunk_text_paths:
292
+ try:
293
+ with open(text_path, 'r', encoding='utf-8') as f:
294
+ text = f.read().strip()
295
+ chunk_texts.append(text)
296
+ except Exception as e:
297
+ print(f"Ошибка чтения {text_path}: {e}")
298
+ chunk_texts.append("")
299
+
300
+ # Group images by size
301
+ size_groups = {}
302
+ for i in range(len(chunk_image_paths)):
303
+ size_key = (chunk_widths[i], chunk_heights[i])
304
+ if size_key not in size_groups:
305
+ size_groups[size_key] = {"image_paths": [], "texts": []}
306
+ size_groups[size_key]["image_paths"].append(chunk_image_paths[i])
307
+ size_groups[size_key]["texts"].append(chunk_texts[i])
308
+
309
+ # Process each size group separately
310
+ for size_key, group_data in size_groups.items():
311
+ print(f"Обработка группы с размером {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} изображений")
312
+
313
+ group_dataset = Dataset.from_dict({
314
+ "image_path": group_data["image_paths"],
315
+ "text": group_data["texts"]
316
+ })
317
+
318
+ # The requested batch_size is safe here, since all images in the group share one size
319
+ processed_group = group_dataset.map(
320
+ lambda examples: encode_to_latents(
321
+ [Image.open(path) for path in examples["image_path"]],
322
+ examples["text"]
323
+ ),
324
+ batched=True,
325
+ batch_size=batch_size,
326
+ #remove_columns=["image_path"],
327
+ desc=f"Обработка группы размера {size_key[0]}x{size_key[1]}"
328
+ )
329
+
330
+ # Save the group results
331
+ group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
332
+ processed_group.save_to_disk(group_save_path)
333
+ clear_cuda_memory()
334
+ elapsed = time.time() - start_time
335
+ processed = (chunk_idx - 1) * chunk_size + sum([len(sg["image_paths"]) for sg in list(size_groups.values())[:list(size_groups.values()).index(group_data) + 1]])
336
+ if processed > 0:
337
+ remaining = (elapsed / processed) * (total_files - processed)
338
+ elapsed_str = str(timedelta(seconds=int(elapsed)))
339
+ remaining_str = str(timedelta(seconds=int(remaining)))
340
+ print(f"ETA: Прошло {elapsed_str}, Осталось {remaining_str}, Прогресс {processed}/{total_files} ({processed/total_files:.1%})")
341
+
342
+ # ---------------- 6️⃣ Combining chunks ----------------
343
+ def combine_chunks(temp_path, final_path):
344
+ """Объединение обработанных чанков в финальный датасет"""
345
+ chunks = sorted([
346
+ os.path.join(temp_path, d)
347
+ for d in os.listdir(temp_path)
348
+ if d.startswith("chunk_")
349
+ ])
350
+
351
+ datasets = [load_from_disk(chunk) for chunk in chunks]
352
+ combined = concatenate_datasets(datasets)
353
+ combined.save_to_disk(final_path)
354
+
355
+ print(f"✅ Датасет успешно сохранен в: {final_path}")
356
+
357
+ # Create a temporary folder for the chunks
358
+ temp_path = f"{save_path}_temp"
359
+ os.makedirs(temp_path, exist_ok=True)
360
+
361
+ # Collect the file list
362
+ image_paths, text_paths, width, height = process_folder(folder_path,limit)
363
+ print(f"Всего найдено {len(image_paths)} изображений")
364
+
365
+ # Chunked processing
366
+ process_in_chunks(image_paths, text_paths, width, height, chunk_size=25000, batch_size=batch_size)
367
+
368
+ # Merge the chunks into the final dataset
369
+ combine_chunks(temp_path, save_path)
370
+
371
+ # Remove the temporary folder
372
+ try:
373
+ shutil.rmtree(temp_path)
374
+ print(f"✅ Временная папка {temp_path} успешно удалена")
375
+ except Exception as e:
376
+ print(f"⚠️ Ошибка при удалении временной папки: {e}")
media/result_grid.jpg CHANGED

Git LFS Details (before)

  • SHA256: 5ca79ca46ceff1fe1d1500cd2a12056a4e2d5d1a82dcee3b6c5f72531652af9e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.57 MB

Git LFS Details (after)

  • SHA256: 286c7affa350104e6b2a790707f1136741f5855c8c05f96d82faad42f694a7b5
  • Pointer size: 132 Bytes
  • Size of remote file: 4.48 MB
pipeline_sdxs.py CHANGED
@@ -49,8 +49,14 @@ class SdxsPipeline(DiffusionPipeline):
49
  padding_side = 'right',
50
  max_length=max_length
51
  ).to(device)
52
- outs = self.text_encoder(**toks, output_hidden_states=True)
53
- return outs.hidden_states[-1]
54
 
55
  # Encode the positive and negative prompts
56
  pos_embeddings = encode_texts(prompt) if prompt is not None else None
 
49
  padding_side = 'right',
50
  max_length=max_length
51
  ).to(device)
52
+ outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True)
53
+ hidden = outs.hidden_states[-1]
54
+ mask = toks["attention_mask"].unsqueeze(-1) # (B, L, 1)
55
+
56
+ # Zero out embeddings at pad token positions
57
+ hidden = hidden * mask
58
+
59
+ return hidden
60
 
61
  # Encode the positive and negative prompts
62
  pos_embeddings = encode_texts(prompt) if prompt is not None else None
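The pipeline change mirrors the dataset builder: hidden states at pad positions are zeroed out, so the conditioning depends only on the real tokens rather than on whatever the LM emits for padding. A toy sketch of the masking step in isolation:

import torch

def mask_pad_embeddings(hidden, attention_mask):
    # (B, L, D) * (B, L, 1): rows at pad positions become exactly zero
    return hidden * attention_mask.unsqueeze(-1).to(hidden.dtype)

hidden = torch.randn(2, 5, 4)                           # toy hidden states
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
assert mask_pad_embeddings(hidden, mask)[0, 3:].abs().sum() == 0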
src/dataset_mjnj.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8cf317c438de242a8cc0c7d710c00ceec53e887108b081235a1fb05dae0074b0
3
- size 23158
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eb6dcd0ea9aef5fc624bd91abb3058a663d2a66f2843ee9817cd7927b396523
3
+ size 23542
test.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:19c5220e0f9401dbd9abae23049b050b0f540b55205c46894e1a7e57cf0ab6bf
3
- size 1076278
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00123f2a11487702da78dff4cca6f78c71a25549eb878fac8412c752d2e45f19
3
+ size 5703503
train-Copy1.py ADDED
@@ -0,0 +1,771 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader, Sampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+ from collections import defaultdict
10
+ from torch.optim.lr_scheduler import LambdaLR
11
+ from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
12
+ from accelerate import Accelerator
13
+ from datasets import load_from_disk
14
+ from tqdm import tqdm
15
+ from PIL import Image,ImageOps
16
+ import wandb
17
+ import random
18
+ import gc
19
+ from accelerate.state import DistributedType
20
+ from torch.distributed import broadcast_object_list
21
+ from torch.utils.checkpoint import checkpoint
22
+ from diffusers.models.attention_processor import AttnProcessor2_0
23
+ from datetime import datetime
24
+ import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
+ from collections import deque
27
+
28
+ # --------------------------- Parameters ---------------------------
29
+ ds_path = "/workspace/sdxs/datasets/640"
30
+ project = "unet"
31
+ batch_size = 48
32
+ base_learning_rate = 4e-5
33
+ min_learning_rate = 2e-5
34
+ num_epochs = 50
35
+ # samples/save per epoch
36
+ sample_interval_share = 5
37
+ use_wandb = True
38
+ use_comet_ml = False
39
+ save_model = True
40
+ use_decay = True
41
+ fbp = False # fused backward pass
42
+ optimizer_type = "adam8bit"
43
+ torch_compile = False
44
+ unet_gradient = True
45
+ clip_sample = False #Scheduler
46
+ fixed_seed = True
47
+ shuffle = True
48
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Comet ML API key
49
+ comet_ml_workspace = "recoilme" # Comet ML workspace
50
+ torch.backends.cuda.matmul.allow_tf32 = True
51
+ torch.backends.cudnn.allow_tf32 = True
52
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
53
+ dtype = torch.float32
54
+ save_barrier = 1.006
55
+ warmup_percent = 0.005
56
+ percentile_clipping = 99 # 8bit optim
57
+ betta2 = 0.99
58
+ eps = 1e-8
59
+ clip_grad_norm = 1.0
60
+ steps_offset = 0 # Scheduler
61
+ limit = 0
62
+ checkpoints_folder = ""
63
+ mixed_precision = "no" #"fp16"
64
+ gradient_accumulation_steps = 1
65
+ accelerator = Accelerator(
66
+ mixed_precision=mixed_precision,
67
+ gradient_accumulation_steps=gradient_accumulation_steps
68
+ )
69
+ device = accelerator.device
70
+
71
+ # Diffusion parameters
72
+ n_diffusion_steps = 50
73
+ samples_to_generate = 12
74
+ guidance_scale = 4
75
+
76
+ # Output folders
77
+ generated_folder = "samples"
78
+ os.makedirs(generated_folder, exist_ok=True)
79
+
80
+ # Seed setup for reproducibility
81
+ current_date = datetime.now()
82
+ seed = int(current_date.strftime("%Y%m%d"))
83
+ if fixed_seed:
84
+ torch.manual_seed(seed)
85
+ np.random.seed(seed)
86
+ random.seed(seed)
87
+ if torch.cuda.is_available():
88
+ torch.cuda.manual_seed_all(seed)
89
+
90
+ # --------------------------- LoRA parameters ---------------------------
91
+ lora_name = ""
92
+ lora_rank = 32
93
+ lora_alpha = 64
94
+
95
+ print("init")
96
+
97
+ # --------------------------- Helper functions ---------------------------
98
+ def sample_timesteps_bias(
99
+ batch_size: int,
100
+ progress: float, # [0..1]
101
+ num_train_timesteps: int, # usually 1000
102
+ steps_offset: int = 0,
103
+ device=None,
104
+ mode: str = "beta", # "beta", "uniform"
105
+ ) -> torch.Tensor:
106
+ """
107
+ Returns timesteps with a chosen bias:
108
+ - beta : shifted toward the start or the end, depending on progress
109
+ - uniform: uniform over all timesteps
110
+ (only these two modes are implemented; any other mode raises ValueError)
111
+ """
112
+
113
+ max_idx = num_train_timesteps - 1 - steps_offset
114
+
115
+ if mode == "beta":
116
+ alpha = 1.0 + .5 * (1.0 - progress)
117
+ beta = 1.0 + .5 * progress
118
+ samples = torch.distributions.Beta(alpha, beta).sample((batch_size,))
119
+
120
+ elif mode == "uniform":
121
+ samples = torch.rand(batch_size)
122
+
123
+ else:
124
+ raise ValueError(f"Unknown mode: {mode}")
125
+
126
+ timesteps = steps_offset + (samples * max_idx).long().to(device)
127
+ return timesteps
128
+
129
+ def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
130
+ normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
131
+
132
+ logit_normal_samples = torch.sigmoid(normal_samples)
133
+
134
+ return logit_normal_samples
135
+
136
+ # --------------------------- WandB initialization ---------------------------
137
+ if accelerator.is_main_process:
138
+ if use_wandb:
139
+ wandb.init(project=project+lora_name, config={
140
+ "batch_size": batch_size,
141
+ "base_learning_rate": base_learning_rate,
142
+ "num_epochs": num_epochs,
143
+ "fbp": fbp,
144
+ "optimizer_type": optimizer_type,
145
+ })
146
+ if use_comet_ml:
147
+ from comet_ml import Experiment
148
+ comet_experiment = Experiment(
149
+ api_key=comet_ml_api_key,
150
+ project_name=project,
151
+ workspace=comet_ml_workspace
152
+ )
153
+ # Log hyperparameters to Comet ML
154
+ hyper_params = {
155
+ "batch_size": batch_size,
156
+ "base_learning_rate": base_learning_rate,
157
+ "min_learning_rate": min_learning_rate,
158
+ "num_epochs": num_epochs,
159
+ "n_diffusion_steps": n_diffusion_steps,
160
+ "guidance_scale": guidance_scale,
161
+ "optimizer_type": optimizer_type,
162
+ "mixed_precision": mixed_precision,
163
+ }
164
+ comet_experiment.log_parameters(hyper_params)
165
+
166
+ # Enable Flash Attention 2 / SDPA
167
+ torch.backends.cuda.enable_flash_sdp(True)
168
+ # --------------------------- RNG generator setup --------------------
169
+ gen = torch.Generator(device=device)
170
+ gen.manual_seed(seed)
171
+
172
+ # --------------------------- Model loading ---------------------------
173
+ # The VAE stays on CPU to save GPU memory (as in the original code)
174
+ vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
175
+
176
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
177
+ if shift_factor is None:
178
+ shift_factor = 0.0
179
+
180
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
181
+ if scaling_factor is None:
182
+ scaling_factor = 1.0
183
+
184
+ latents_mean = getattr(vae.config, "latents_mean", None)
185
+ latents_std = getattr(vae.config, "latents_std", None)
186
+
187
+ from diffusers import FlowMatchEulerDiscreteScheduler
188
+
189
+ # Adjust to your setup
190
+ num_train_timesteps = 1000
191
+
192
+ scheduler = FlowMatchEulerDiscreteScheduler(
193
+ num_train_timesteps=num_train_timesteps,
194
+ #shift=3.0, # example; tune if needed
195
+ #use_dynamic_shifting=True
196
+ )
197
+
198
+
199
+ class DistributedResolutionBatchSampler(Sampler):
200
+ def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
201
+ self.dataset = dataset
202
+ self.batch_size = max(1, batch_size // num_replicas)
203
+ self.num_replicas = num_replicas
204
+ self.rank = rank
205
+ self.shuffle = shuffle
206
+ self.drop_last = drop_last
207
+ self.epoch = 0
208
+
209
+ try:
210
+ widths = np.array(dataset["width"])
211
+ heights = np.array(dataset["height"])
212
+ except KeyError:
213
+ widths = np.zeros(len(dataset))
214
+ heights = np.zeros(len(dataset))
215
+
216
+ self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
217
+ self.size_groups = {}
218
+ for w, h in self.size_keys:
219
+ mask = (widths == w) & (heights == h)
220
+ self.size_groups[(w, h)] = np.where(mask)[0]
221
+
222
+ self.group_num_batches = {}
223
+ total_batches = 0
224
+ for size, indices in self.size_groups.items():
225
+ num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
226
+ self.group_num_batches[size] = num_full_batches
227
+ total_batches += num_full_batches
228
+
229
+ self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
230
+
231
+ def __iter__(self):
232
+ if torch.cuda.is_available():
233
+ torch.cuda.empty_cache()
234
+ all_batches = []
235
+ rng = np.random.RandomState(self.epoch)
236
+
237
+ for size, indices in self.size_groups.items():
238
+ indices = indices.copy()
239
+ if self.shuffle:
240
+ rng.shuffle(indices)
241
+ num_full_batches = self.group_num_batches[size]
242
+ if num_full_batches == 0:
243
+ continue
244
+ valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
245
+ batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
246
+ start_idx = self.rank * self.batch_size
247
+ end_idx = start_idx + self.batch_size
248
+ gpu_batches = batches[:, start_idx:end_idx]
249
+ all_batches.extend(gpu_batches)
250
+
251
+ if self.shuffle:
252
+ rng.shuffle(all_batches)
253
+ accelerator.wait_for_everyone()
254
+ return iter(all_batches)
255
+
256
+ def __len__(self):
257
+ return self.num_batches
258
+
259
+ def set_epoch(self, epoch):
260
+ self.epoch = epoch
261
+
262
+ # Select fixed samples per resolution group
263
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
264
+ size_groups = defaultdict(list)
265
+ try:
266
+ widths = dataset["width"]
267
+ heights = dataset["height"]
268
+ except KeyError:
269
+ widths = [0] * len(dataset)
270
+ heights = [0] * len(dataset)
271
+ for i, (w, h) in enumerate(zip(widths, heights)):
272
+ size = (w, h)
273
+ size_groups[size].append(i)
274
+
275
+ fixed_samples = {}
276
+ for size, indices in size_groups.items():
277
+ n_samples = min(samples_per_group, len(indices))
278
+ if len(size_groups)==1:
279
+ n_samples = samples_to_generate
280
+ if n_samples == 0:
281
+ continue
282
+ sample_indices = random.sample(indices, n_samples)
283
+ samples_data = [dataset[idx] for idx in sample_indices]
284
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
285
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
286
+ texts = [item["text"] for item in samples_data]
287
+ fixed_samples[size] = (latents, embeddings, texts)
288
+
289
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
290
+ return fixed_samples
291
+
292
+ if limit > 0:
293
+ dataset = load_from_disk(ds_path).select(range(limit))
294
+ else:
295
+ dataset = load_from_disk(ds_path)
296
+
297
+ def collate_fn_simple(batch):
298
+ latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
299
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
300
+ return latents, embeddings
301
+
302
+ batch_sampler = DistributedResolutionBatchSampler(
303
+ dataset=dataset,
304
+ batch_size=batch_size,
305
+ num_replicas=accelerator.num_processes,
306
+ rank=accelerator.process_index,
307
+ shuffle=shuffle
308
+ )
309
+
310
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
311
+ print("Total samples",len(dataloader))
312
+ dataloader = accelerator.prepare(dataloader)
313
+
314
+ start_epoch = 0
315
+ global_step = 0
316
+ total_training_steps = (len(dataloader) * num_epochs)
317
+ world_size = accelerator.state.num_processes
318
+
319
+ # Optionally load the model from the latest checkpoint (if it exists)
320
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
321
+ if os.path.isdir(latest_checkpoint):
322
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
323
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
324
+ if unet_gradient:
325
+ unet.enable_gradient_checkpointing()
326
+ unet.set_use_memory_efficient_attention_xformers(False)
327
+ try:
328
+ unet.set_attn_processor(AttnProcessor2_0())
329
+ except Exception as e:
330
+ print(f"Ошибка при включении SDPA: {e}")
331
+ unet.set_use_memory_efficient_attention_xformers(True)
332
+
333
+ else:
334
+ # FIX: if there is no checkpoint, stop with a clear error (better than an unexpected NameError later)
335
+ raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Put a UNet checkpoint there or point to another path.")
336
+
337
+ if lora_name:
338
+ print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
339
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
340
+ from peft.tuners.lora import LoraModel
341
+ import os
342
+ unet.requires_grad_(False)
343
+ print("Параметры базового UNet заморожены.")
344
+
345
+ lora_config = LoraConfig(
346
+ r=lora_rank,
347
+ lora_alpha=lora_alpha,
348
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
349
+ )
350
+ unet.add_adapter(lora_config)
351
+
352
+ from peft import get_peft_model
353
+ peft_unet = get_peft_model(unet, lora_config)
354
+ params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
355
+
356
+ if accelerator.is_main_process:
357
+ lora_params_count = sum(p.numel() for p in params_to_optimize)
358
+ total_params_count = sum(p.numel() for p in unet.parameters())
359
+ print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
360
+ print(f"Общее количество параметров UNet: {total_params_count:,}")
361
+
362
+ lora_save_path = os.path.join("lora", lora_name)
363
+ os.makedirs(lora_save_path, exist_ok=True)
364
+
365
+ def save_lora_checkpoint(model):
366
+ if accelerator.is_main_process:
367
+ print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
368
+ from peft.utils.save_and_load import get_peft_model_state_dict
369
+ lora_state_dict = get_peft_model_state_dict(model)
370
+ torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
371
+ model.peft_config["default"].save_pretrained(lora_save_path)
372
+ from diffusers import StableDiffusionXLPipeline
373
+ StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
374
+
375
+ # --------------------------- Optimizer ---------------------------
376
+ if lora_name:
377
+ trainable_params = [p for p in unet.parameters() if p.requires_grad]
378
+ else:
379
+ if fbp:
380
+ trainable_params = list(unet.parameters())
381
+
382
+ def create_optimizer(name, params):
383
+ if name == "adam8bit":
384
+ return bnb.optim.AdamW8bit(
385
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
386
+ percentile_clipping=percentile_clipping
387
+ )
388
+ elif name == "adam":
389
+ return torch.optim.AdamW(
390
+ params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
391
+ )
392
+ else:
393
+ raise ValueError(f"Unknown optimizer: {name}")
394
+
395
+ if fbp:
396
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
397
+ def optimizer_hook(param):
398
+ optimizer_dict[param].step()
399
+ optimizer_dict[param].zero_grad(set_to_none=True)
400
+ for param in trainable_params:
401
+ param.register_post_accumulate_grad_hook(optimizer_hook)
402
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
403
+ else:
404
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
405
+ def lr_schedule(step):
406
+ x = step / (total_training_steps * world_size)
407
+ warmup = warmup_percent
408
+ if not use_decay:
409
+ return base_learning_rate
410
+ if x < warmup:
411
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
412
+ decay_ratio = (x - warmup) / (1 - warmup)
413
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
414
+ (1 + math.cos(math.pi * decay_ratio))
415
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
416
+
417
+ num_params = sum(p.numel() for p in unet.parameters())
418
+ print(f"[rank {accelerator.process_index}] total params: {num_params}")
419
+ for name, param in unet.named_parameters():
420
+ if torch.isnan(param).any() or torch.isinf(param).any():
421
+ print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
422
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
423
+
424
+ if torch_compile:
425
+ print("compiling")
426
+ torch.set_float32_matmul_precision('high')
427
+ torch.backends.cudnn.allow_tf32 = True
428
+ torch.backends.cuda.matmul.allow_tf32 = True
429
+ unet = torch.compile(unet)#, mode='max-autotune')
430
+ print("compiling - ok")
431
+
432
+ # --------------------------- Fixed samples for generation ---------------------------
433
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
434
+
435
+ def get_negative_embedding(neg_prompt="", batch_size=1):
436
+ """
437
+ Returns the negative-prompt embedding with a batch dimension.
438
+ Loads the models, computes the embedding, then offloads the models to CPU.
439
+ """
440
+ import torch
441
+ from transformers import AutoTokenizer, AutoModel
442
+
443
+ # Settings
444
+ dtype = torch.float16
445
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
446
+
447
+ # Load the models (if not loaded yet)
448
+ if not hasattr(get_negative_embedding, "tokenizer"):
449
+ get_negative_embedding.tokenizer = AutoTokenizer.from_pretrained(
450
+ "Qwen/Qwen3-0.6B"
451
+ )
452
+ get_negative_embedding.text_model = AutoModel.from_pretrained(
453
+ "Qwen/Qwen3-0.6B"
454
+ ).to(device).eval()
455
+
456
+ # Compute the embedding
457
+ def encode_texts(texts, max_length=150):
458
+ with torch.inference_mode():
459
+ toks = get_negative_embedding.tokenizer(
460
+ texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length
461
+ ).to(device)
462
+
463
+ outs = get_negative_embedding.text_model(**toks, output_hidden_states=True)
464
+ hidden_states = outs.hidden_states[-1] # [B, L, D]
465
+ return hidden_states
466
+
467
+ # Return the embedding
468
+ if not neg_prompt:
469
+ hidden_dim = 1024 # hidden size of Qwen3-0.6B
470
+ seq_len = 150
471
+ return torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
472
+
473
+ uncond_emb = encode_texts([neg_prompt]).to(dtype=dtype, device=device)
474
+ uncond_emb = uncond_emb.repeat(batch_size, 1, 1) # expand to the batch
475
+
476
+ # Offload the models
477
+ if hasattr(get_negative_embedding, "text_model"):
478
+ get_negative_embedding.text_model = get_negative_embedding.text_model.to("cpu")
479
+ if hasattr(get_negative_embedding, "tokenizer"):
480
+ del get_negative_embedding.tokenizer # free memory
481
+ torch.cuda.empty_cache()
482
+
483
+ return uncond_emb
484
+
485
+ uncond_emb = get_negative_embedding("low quality")
486
+
487
+ @torch.compiler.disable()
488
+ @torch.no_grad()
489
+ def generate_and_save_samples(fixed_samples_cpu,empty_embeddings, step):
490
+ original_model = None
491
+ try:
492
+ # safe unwrap: no unwrap needed when compiled
493
+ if not torch_compile:
494
+ original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
495
+ else:
496
+ original_model = unet.eval()
497
+
498
+ vae.to(device=device).eval() # temporarily move the VAE to GPU for decoding
499
+
500
+
501
+ all_generated_images = []
502
+ all_captions = []
503
+
504
+ for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
505
+ width, height = size
506
+ sample_latents = sample_latents.to(dtype=dtype, device=device)
507
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
508
+
509
+ # initial noise
510
+ latents = torch.randn(
511
+ sample_latents.shape,
512
+ device=device,
513
+ dtype=sample_latents.dtype,
514
+ generator=torch.Generator(device=device).manual_seed(seed)
515
+ )
516
+
517
+ # prepare timesteps via the scheduler
518
+ scheduler.set_timesteps(n_diffusion_steps, device=device)
519
+
520
+ for t in scheduler.timesteps:
521
+ # guidance: double the batch
522
+ if guidance_scale != 1:
523
+ latent_model_input = torch.cat([latents, latents], dim=0)
524
+
525
+ # empty_embeddings: repeat over seq_len and batch
526
+ seq_len = sample_text_embeddings.shape[1]
527
+ hidden_dim = sample_text_embeddings.shape[2]
528
+ empty_embeddings_exp = empty_embeddings.expand(-1, seq_len, hidden_dim) # [1, seq_len, hidden_dim]
529
+ empty_embeddings_exp = empty_embeddings_exp.repeat(sample_text_embeddings.shape[0], 1, 1) # [batch, seq_len, hidden_dim]
530
+
531
+ text_embeddings_batch = torch.cat([empty_embeddings_exp, sample_text_embeddings], dim=0)
532
+ else:
533
+ latent_model_input = latents
534
+ text_embeddings_batch = sample_text_embeddings
535
+
536
+
537
+
538
+ # flow (velocity) prediction
539
+ model_out = original_model(latent_model_input, t, encoder_hidden_states=text_embeddings_batch)
540
+ flow = getattr(model_out, "sample", model_out)
541
+
542
+ # combine the guidance branches
543
+ if guidance_scale != 1:
544
+ flow_uncond, flow_cond = flow.chunk(2)
545
+ flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
546
+
547
+ # scheduler step
548
+ latents = scheduler.step(flow, t, latents).prev_sample
549
+
550
+ current_latents = latents
551
+
552
+
553
+ # Undo the latent normalization for the VAE
554
+ latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
555
+
556
+ decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
557
+ #decoded = decoded[:, :, 0, :, :] # [3, H, W]
558
+ #print(decoded.ndim, decoded.shape)
559
+
560
+ decoded_fp32 = decoded.to(torch.float32)
561
+ for img_idx, img_tensor in enumerate(decoded_fp32):
562
+
563
+ # Shape: [3, H, W] -> [H, W, 3]
564
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
565
+ img = img.transpose(1, 2, 0) # Из [3, H, W] в [H, W, 3]
566
+
567
+ #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
568
+ if np.isnan(img).any():
569
+ print("NaNs found, saving stopped! Step:", step)
570
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
571
+
572
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
573
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
574
+ max_w_overall = max(255, max_w_overall)
575
+ max_h_overall = max(255, max_h_overall)
576
+
577
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
578
+ all_generated_images.append(padded_img)
579
+
580
+ caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
581
+ all_captions.append(caption_text)
582
+
583
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
584
+ pil_img.save(sample_path, "JPEG", quality=96)
585
+
586
+ if use_wandb and accelerator.is_main_process:
587
+ wandb_images = [
588
+ wandb.Image(img, caption=f"{all_captions[i]}")
589
+ for i, img in enumerate(all_generated_images)
590
+ ]
591
+ wandb.log({"generated_images": wandb_images})
592
+ if use_comet_ml and accelerator.is_main_process:
593
+ for i, img in enumerate(all_generated_images):
594
+ comet_experiment.log_image(
595
+ image_data=img,
596
+ name=f"step_{step}_img_{i}",
597
+ step=step,
598
+ metadata={
599
+ "caption": all_captions[i],
600
+ "width": img.width,
601
+ "height": img.height,
602
+ "global_step": step
603
+ }
604
+ )
605
+ finally:
606
+ # move the VAE back to CPU
607
+ vae.to("cpu")
608
+ # NOTE: the original loop that del'ed tensors out of locals() was a no-op
609
+ # (mutating locals() has no effect in CPython), so it is dropped here and
610
+ # the empty_cache()/gc.collect() calls below reclaim the memory
611
+ torch.cuda.empty_cache()
612
+ gc.collect()
613
+
614
+ # --------------------------- Sample generation before training ---------------------------
615
+ if accelerator.is_main_process:
616
+ if save_model:
617
+ print("Генерация сэмплов до старта обучения...")
618
+ generate_and_save_samples(fixed_samples,uncond_emb,0)
619
+ accelerator.wait_for_everyone()
620
+
621
+ # Checkpoint saving, with LoRA support
622
+ def save_checkpoint(unet, variant=""):
623
+ if accelerator.is_main_process:
624
+ if lora_name:
625
+ save_lora_checkpoint(unet)
626
+ else:
627
+ # safe unwrap for a compiled model
628
+ model_to_save = None
629
+ if not torch_compile:
630
+ model_to_save = accelerator.unwrap_model(unet)
631
+ else:
632
+ model_to_save = unet
633
+
634
+ if variant != "":
635
+ model_to_save.to(dtype=torch.float16).save_pretrained(
636
+ os.path.join(checkpoints_folder, f"{project}"), variant=variant
637
+ )
638
+ else:
639
+ model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
640
+
641
+ unet = unet.to(dtype=dtype)
642
+
643
+ # --------------------------- Training loop ---------------------------
644
+ if accelerator.is_main_process:
645
+ print(f"Total steps per GPU: {total_training_steps}")
646
+
647
+ epoch_loss_points = []
648
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
649
+
650
+ steps_per_epoch = len(dataloader)
651
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
652
+ min_loss = 2.
653
+
654
+ for epoch in range(start_epoch, start_epoch + num_epochs):
655
+ batch_losses = []
656
+ batch_grads = []
657
+ batch_sampler.set_epoch(epoch)
658
+ accelerator.wait_for_everyone()
659
+ unet.train()
660
+ #print("epoch:",epoch)
661
+ for step, (latents, embeddings) in enumerate(dataloader):
662
+ with accelerator.accumulate(unet):
663
+ if save_model == False and step == 5 :
664
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
665
+ print(f"Шаг {step}: {used_gb:.2f} GB")
666
+
667
+ # noise
668
+ noise = torch.randn_like(latents, dtype=latents.dtype)
669
+
670
+ # sample t from [0, 1]
671
+ t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
672
+
673
+ # interpolate between x0 and noise
674
+ noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
675
+
676
+ # integer timesteps for the UNet
677
+ timesteps = (t * scheduler.config.num_train_timesteps).long()
678
+
679
+ # flow prediction
680
+ model_pred = unet(noisy_latents, timesteps, embeddings).sample
681
+
682
+ # target is the vector field (= difference between the endpoints)
683
+ target = noise - latents # matches x_t = (1 - t) * x0 + t * noise, so dx/dt = noise - x0
684
+
685
+ # MSE loss
686
+ mse_loss = F.mse_loss(model_pred.float(), target.float())
687
+
688
+ # Keep for logging (MSE tracked separately as an indicator)
689
+ batch_losses.append(mse_loss.detach().item())
690
+
691
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
692
+ accelerator.wait_for_everyone()
693
+
694
+ # Backward
695
+ accelerator.backward(mse_loss)
696
+
697
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
698
+ accelerator.wait_for_everyone()
699
+
700
+ grad = 0.0
701
+ if not fbp:
702
+ if accelerator.sync_gradients:
703
+ with torch.amp.autocast('cuda', enabled=False):
704
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
705
+ grad = float(grad_val)
706
+ optimizer.step()
707
+ lr_scheduler.step()
708
+ optimizer.zero_grad(set_to_none=True)
709
+
710
+ if accelerator.sync_gradients:
711
+ global_step += 1
712
+ progress_bar.update(1)
713
+ # Log metrics
714
+ if accelerator.is_main_process:
715
+ if fbp:
716
+ current_lr = base_learning_rate
717
+ else:
718
+ current_lr = lr_scheduler.get_last_lr()[0]
719
+ batch_grads.append(grad)
720
+
721
+ log_data = {}
722
+ log_data["loss"] = mse_loss.detach().item()
723
+ log_data["lr"] = current_lr
724
+ log_data["grad"] = grad
725
+ if accelerator.sync_gradients:
726
+ if use_wandb:
727
+ wandb.log(log_data, step=global_step)
728
+ if use_comet_ml:
729
+ comet_experiment.log_metrics(log_data, step=global_step)
730
+
731
+ # Generate samples at the configured interval
732
+ if global_step % sample_interval == 0:
733
+ generate_and_save_samples(fixed_samples,uncond_emb, global_step)
734
+ last_n = sample_interval
735
+
736
+ if save_model:
737
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
738
+ print("saving:", avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
739
+ if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
740
+ min_loss = avg_sample_loss
741
+ save_checkpoint(unet)
742
+
743
+
744
+ if accelerator.is_main_process:
745
+ # local averages
746
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
747
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
748
+
749
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
750
+ log_data_ep = {
751
+ "epoch_loss": avg_epoch_loss,
752
+ "epoch_grad": avg_epoch_grad,
753
+ "epoch": epoch + 1,
754
+ }
755
+ if use_wandb:
756
+ wandb.log(log_data_ep)
757
+ if use_comet_ml:
758
+ comet_experiment.log_metrics(log_data_ep)
759
+
760
+ # Training finished: save the final model
761
+ if accelerator.is_main_process:
762
+ print("Обучение завершено! Сохраняем финальную модель...")
763
+ if save_model:
764
+ save_checkpoint(unet,"fp16")
765
+ if use_comet_ml:
766
+ comet_experiment.end()
767
+ accelerator.free_memory()
768
+ if torch.distributed.is_initialized():
769
+ torch.distributed.destroy_process_group()
770
+
771
+ print("Готово!")
train.py CHANGED
@@ -26,11 +26,11 @@ import torch.nn.functional as F
26
  from collections import deque
27
 
28
  # --------------------------- Parameters ---------------------------
29
- ds_path = "/workspace/sdxs/datasets/640"
30
  project = "unet"
31
  batch_size = 48
32
- base_learning_rate = 4e-5
33
- min_learning_rate = 2e-5
34
  num_epochs = 50
35
  # samples/save per epoch
36
  sample_interval_share = 5
@@ -52,7 +52,7 @@ torch.backends.cudnn.allow_tf32 = True
52
  torch.backends.cuda.enable_mem_efficient_sdp(False)
53
  dtype = torch.float32
54
  save_barrier = 1.006
55
- warmup_percent = 0.005
56
  percentile_clipping = 99 # 8bit optim
57
  betta2 = 0.99
58
  eps = 1e-8
@@ -94,45 +94,6 @@ lora_alpha = 64
94
 
95
  print("init")
96
 
97
- # --------------------------- Helper functions ---------------------------
98
- def sample_timesteps_bias(
99
- batch_size: int,
100
- progress: float, # [0..1]
101
- num_train_timesteps: int, # usually 1000
102
- steps_offset: int = 0,
103
- device=None,
104
- mode: str = "beta", # "beta", "uniform"
105
- ) -> torch.Tensor:
106
- """
107
- Returns timesteps with a chosen bias:
108
- - beta : shifted toward the start or the end, depending on progress
109
- - uniform: uniform over all timesteps
110
- (only these two modes are implemented; any other mode raises ValueError)
111
- """
112
-
113
- max_idx = num_train_timesteps - 1 - steps_offset
114
-
115
- if mode == "beta":
116
- alpha = 1.0 + .5 * (1.0 - progress)
117
- beta = 1.0 + .5 * progress
118
- samples = torch.distributions.Beta(alpha, beta).sample((batch_size,))
119
-
120
- elif mode == "uniform":
121
- samples = torch.rand(batch_size)
122
-
123
- else:
124
- raise ValueError(f"Unknown mode: {mode}")
125
-
126
- timesteps = steps_offset + (samples * max_idx).long().to(device)
127
- return timesteps
128
-
129
- def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
130
- normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
131
-
132
- logit_normal_samples = torch.sigmoid(normal_samples)
133
-
134
- return logit_normal_samples
135
-
136
  # --------------------------- WandB initialization ---------------------------
137
  if accelerator.is_main_process:
138
  if use_wandb:
@@ -460,9 +421,12 @@ def get_negative_embedding(neg_prompt="", batch_size=1):
460
  texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length
461
  ).to(device)
462
 
463
- outs = get_negative_embedding.text_model(**toks, output_hidden_states=True)
464
- hidden_states = outs.hidden_states[-1] # [B, L, D]
465
- return hidden_states
466
 
467
  # Return the embedding
468
  if not neg_prompt:
 
26
  from collections import deque
27
 
28
  # --------------------------- Parameters ---------------------------
29
+ ds_path = "/workspace/sdxs/datasets/640_mjnj"
30
  project = "unet"
31
  batch_size = 48
32
+ base_learning_rate = 5e-5
33
+ min_learning_rate = 2.5e-5
34
  num_epochs = 50
35
  # samples/save per epoch
36
  sample_interval_share = 5
 
52
  torch.backends.cuda.enable_mem_efficient_sdp(False)
53
  dtype = torch.float32
54
  save_barrier = 1.006
55
+ warmup_percent = 0.01
56
  percentile_clipping = 99 # 8bit optim
57
  betta2 = 0.99
58
  eps = 1e-8
 
94
 
95
  print("init")
96
 
97
  # --------------------------- WandB initialization ---------------------------
98
  if accelerator.is_main_process:
99
  if use_wandb:
 
421
  texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length
422
  ).to(device)
423
 
424
+ outs = get_negative_embedding.text_model(**toks, output_hidden_states=True, return_dict=True)
425
+ hidden = outs.hidden_states[-1] # [B, L, D]
426
+ mask = toks["attention_mask"].unsqueeze(-1) # (B, L, 1)
427
+ hidden = hidden * mask
428
+
429
+ return hidden
430
 
431
  # Return the embedding
432
  if not neg_prompt:
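For context on the retuned constants in train.py (base_learning_rate 5e-5, min_learning_rate 2.5e-5, warmup_percent 0.01): the lr_schedule they feed is a linear warmup over the first 1% of steps followed by a cosine decay back to the minimum. The same formula, shown standalone with the new values:

import math

def lr_at(x, base=5e-5, mn=2.5e-5, warmup=0.01):
    # x is the completed fraction of total training steps, in [0, 1]
    if x < warmup:
        return mn + (base - mn) * (x / warmup)
    decay = (x - warmup) / (1 - warmup)
    return mn + 0.5 * (base - mn) * (1 + math.cos(math.pi * decay))

print(lr_at(0.0), lr_at(0.01), lr_at(0.5), lr_at(1.0))  # ramps 2.5e-5 -> 5e-5, then decays to 2.5e-5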