recoilme commited on
Commit
b9a43be
·
1 Parent(s): 7df52ae
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Jupyter Notebook
2
+ __pycache__/
3
+ *.pyc
4
+ .ipynb_checkpoints/
5
+ *.ipynb_checkpoints/*
6
+ .ipynb_checkpoints/*
7
+ src/samples
8
+ # cache
9
+ cache
10
+ datasets
11
+ test
12
+ wandb
13
+ nohup.out
TRAIN.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ Краткая инструкция по установке
6
+ Обновите систему и установите git-lfs:
7
+
8
+ ```
9
+ apt update
10
+ apt install git-lfs
11
+ git config --global credential.helper store
12
+ ```
13
+ Обновите pip и установите требуемые пакеты:
14
+
15
+ ```
16
+ python -m pip install --upgrade pip
17
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 -U
18
+ pip install flash-attn --no-build-isolation # optional
19
+ ```
20
+ Клонируйте репозиторий:
21
+
22
+ ```
23
+ git clone https://huggingface.co/AiArtLab/sdxs
24
+ cd sdxs/
25
+ pip install -r requirements.txt
26
+ ```
27
+ Подготовьте датасет:
28
+
29
+ ```
30
+ mkdir datasets
31
+ cd datasets
32
+ huggingface-cli download AiArtLab/384 --local-dir 384 --repo-type dataset
33
+ ```
34
+ Выполните вход в сервисы:
35
+
36
+ ```
37
+ huggingface-cli login
38
+ wandb login
39
+ ```
40
+ Запустите обучение!
41
+
42
+ ```
43
+ nohup accelerate launch train.py &
44
+ ```
butterfly.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b923bef9a5d1fe7103e960c943c110ec46155fc71d7f45e0070f3ef072bbdcb
3
+ size 237918081
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # torch>=2.6.0
2
+ # torchvision>=0.21.0
3
+ # torchaudio>=2.6.0
4
+ diffusers>=0.32.2
5
+ accelerate>=1.5.2
6
+ datasets>=3.5.0
7
+ matplotlib>=3.10.1
8
+ wandb>=0.19.8
9
+ huggingface_hub>=0.29.3
10
+ bitsandbytes>=0.45.4
11
+ transformers
src/dataset_from_folder.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install flash-attn --no-build-isolation
2
+ import torch
3
+ import os
4
+ import gc
5
+ import numpy as np
6
+ import random
7
+ import json
8
+ import shutil
9
+ import time
10
+
11
+ from datasets import Dataset, load_from_disk, concatenate_datasets
12
+ from diffusers import AutoencoderKL,AutoencoderKLWan
13
+ from torchvision.transforms import Resize, ToTensor, Normalize, Compose, InterpolationMode, Lambda
14
+ from transformers import AutoModel, AutoImageProcessor, AutoTokenizer
15
+ from typing import Dict, List, Tuple, Optional, Any
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+ from datetime import timedelta
19
+
20
+ # ---------------- 1️⃣ Настройки ----------------
21
+ dtype = torch.float16
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ batch_size = 5
24
+ min_size = 192 #256 #192
25
+ max_size = 384 #256 #384
26
+ step = 64
27
+ img_share = 1.0
28
+ empty_share = 0.05
29
+ limit = 0
30
+ textemb_full = False
31
+ # Основная процедура обработки
32
+ folder_path = "/workspace/butterfly" #alchemist"
33
+ save_path = "/workspace/sdxs3d/datasets/butterfly" #"alchemist"
34
+ os.makedirs(save_path, exist_ok=True)
35
+
36
+ # Функция для очистки CUDA памяти
37
+ def clear_cuda_memory():
38
+ if torch.cuda.is_available():
39
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
40
+ print(f"used_gb: {used_gb:.2f} GB")
41
+ torch.cuda.empty_cache()
42
+ gc.collect()
43
+
44
+ # ---------------- 2️⃣ Загрузка моделей ----------------
45
+ def load_models():
46
+ print("Загрузка моделей...")
47
+ #vae = AutoencoderKLWan.from_pretrained("AiArtLab/simplevae",subfolder="wan16x_vae_nightly",torch_dtype=dtype).to(device).eval()
48
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", subfolder=None,torch_dtype=dtype).to(device).eval()
49
+
50
+ #vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to(device).eval()
51
+ #vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell",subfolder="vae",torch_dtype=dtype).to(device).eval()
52
+ #vae = AutoencoderKL.from_pretrained("/home/recoilme/sdxs/vae", variant="fp16",torch_dtype=dtype).to(device).eval()
53
+ model = AutoModel.from_pretrained("visheratin/mexma-siglip2", dtype=dtype, trust_remote_code=True, optimized=True).to(device).eval()
54
+ processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip2", use_fast=True)
55
+ tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip2")
56
+ return vae, model, processor, tokenizer
57
+
58
+ vae, model, processor, tokenizer = load_models()
59
+
60
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
61
+ if shift_factor is None:
62
+ shift_factor = 0.0
63
+
64
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
65
+ if scaling_factor is None:
66
+ scaling_factor = 1.0
67
+
68
+ latents_mean = getattr(vae.config, "latents_mean", None)
69
+ latents_std = getattr(vae.config, "latents_std", None)
70
+
71
+ # ---------------- 3️⃣ Трансформации ----------------
72
+ def get_image_transform(min_size=256, max_size=512, step=64):
73
+ def transform(img, dry_run=False):
74
+ # Сохраняем исходные размеры изображения
75
+ original_width, original_height = img.size
76
+
77
+ # 0. Ресайз: масштабируем изображение, чтобы максимальная сторона была равна max_size
78
+ if original_width >= original_height:
79
+ new_width = max_size
80
+ new_height = int(max_size * original_height / original_width)
81
+ else:
82
+ new_height = max_size
83
+ new_width = int(max_size * original_width / original_height)
84
+
85
+ if new_height < min_size or new_width < min_size:
86
+ # 1. Ресайз: масштабируем изображение, чтобы минимальная сторона была равна min_size
87
+ if original_width <= original_height:
88
+ new_width = min_size
89
+ new_height = int(min_size * original_height / original_width)
90
+ else:
91
+ new_height = min_size
92
+ new_width = int(min_size * original_width / original_height)
93
+
94
+ # 2. Проверка: если одна из сторон превышает max_size, готовимся к обрезке
95
+ crop_width = min(max_size, (new_width // step) * step)
96
+ crop_height = min(max_size, (new_height // step) * step)
97
+
98
+ # Убеждаемся, что размеры обрезки не меньше min_size
99
+ crop_width = max(min_size, crop_width)
100
+ crop_height = max(min_size, crop_height)
101
+
102
+ # Если запрошен только предварительный расчёт размеров
103
+ if dry_run:
104
+ return crop_width, crop_height
105
+
106
+ # Конвертация в RGB и ресайз
107
+ img_resized = img.convert("RGB").resize((new_width, new_height), Image.LANCZOS)
108
+
109
+ # Определение координат обрезки (обрезаем с учетом вотермарок - треть сверху)
110
+ top = (new_height - crop_height) // 3
111
+ left = 0
112
+
113
+ # Обрезка изображения
114
+ img_cropped = img_resized.crop((left, top, left + crop_width, top + crop_height))
115
+
116
+ # Сохраняем итоговые размеры после всех преобразований
117
+ final_width, final_height = img_cropped.size
118
+
119
+ # тензор
120
+ img_tensor = ToTensor()(img_cropped)
121
+ img_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img_tensor)
122
+ return img_tensor, img_cropped, final_width, final_height
123
+
124
+ return transform
125
+
126
+ # ---------------- 4️⃣ Функции обработки ----------------
127
+ def encode_images_batch(images, processor, model, empty_share=0.0):
128
+ """
129
+ images: список PIL.Image
130
+ processor: трансформер для препроцессинга изображений
131
+ model: vision encoder (например, CLIP или подобный)
132
+ empty_share: доля эмбеддингов, которые нужно обнулить
133
+ """
134
+ # Преобразуем весь батч сразу (вместо обхода по каждому изображению)
135
+ processed = processor(images=images, return_tensors="pt")
136
+ pixel_values = processed["pixel_values"].to(device, dtype)
137
+
138
+ with torch.inference_mode():
139
+ outputs = model.vision_model(pixel_values)
140
+ #hidden_states = outputs.last_hidden_state # [B, seq_len, dim]
141
+ pooled = outputs.pooler_output # [B, dim]
142
+
143
+ # Добавляем pooled embedding в конец sequence
144
+ #context = torch.cat([hidden_states, pooled.unsqueeze(1)], dim=1) # [B, seq_len+1, dim]
145
+ context = pooled.unsqueeze(1)
146
+
147
+ # Добавляем нулевые эмбеддинги с вероятностью empty_share
148
+ if empty_share > 0:
149
+ batch_size = context.shape[0]
150
+ num_empty = int(batch_size * empty_share)
151
+ if num_empty > 0:
152
+ zero_embeddings = torch.zeros_like(context[:num_empty])
153
+ context[:num_empty] = zero_embeddings
154
+
155
+ # Преобразуем bfloat16 в float32 если нужно
156
+ if context.dtype == torch.bfloat16:
157
+ context = context.to(torch.float32)
158
+
159
+ return context.cpu().numpy() # [B, seq_len+1, dim]
160
+
161
+
162
+ def encode_texts_batch(texts, tokenizer, model):
163
+ with torch.inference_mode():
164
+ text_tokenized = tokenizer(texts, return_tensors="pt", padding="max_length",
165
+ max_length=512,
166
+ truncation=True).to(device)
167
+ text_embeddings = model.encode_texts(text_tokenized.input_ids, text_tokenized.attention_mask)
168
+ return text_embeddings.unsqueeze(1).cpu().numpy()
169
+
170
+ def encode_texts_batch_full(texts, tokenizer, model):
171
+ with torch.inference_mode():
172
+ text_tokenized = tokenizer(texts, return_tensors="pt", padding="max_length",max_length=512,truncation=True).to(device)
173
+ features = model.text_model(
174
+ input_ids=text_tokenized.input_ids, attention_mask=text_tokenized.attention_mask
175
+ ).last_hidden_state
176
+ features_proj = model.text_projector(features)
177
+ return features_proj.cpu().numpy()
178
+
179
+ def clean_label(label):
180
+ label = label.replace("Image 1", "").replace("Image 2", "").replace("Image 3", "").replace("Image 4", "")
181
+ return label
182
+
183
+ def process_labels_for_guidance(original_labels, prob_to_make_empty=0.01):
184
+ """
185
+ Обрабатывает список меток для classifier-free guidance.
186
+
187
+ С вероятностью prob_to_make_empty:
188
+ - Метка в первом списке заменяется на пустую строку.
189
+ - К метке во втором списке добавляется префикс "zero:".
190
+
191
+ В противном случае метки в обоих списках остаются оригинальными.
192
+
193
+ """
194
+ labels_for_model = []
195
+ labels_for_logging = []
196
+
197
+ for label in original_labels:
198
+ if random.random() < prob_to_make_empty:
199
+ labels_for_model.append("") # Заменяем на пустую строку для модели
200
+ labels_for_logging.append(f"zero: {label}") # Добавляем префикс для логгирования
201
+ else:
202
+ labels_for_model.append(label) # Оставляем оригинальную метку для модели
203
+ labels_for_logging.append(label) # Оставляем оригинальную метку для логгирования
204
+
205
+ return labels_for_model, labels_for_logging
206
+
207
+ def encode_to_latents(images, texts):
208
+ transform = get_image_transform(min_size, max_size, step)
209
+
210
+ try:
211
+ # Обработка изображений (все одинакового размера)
212
+ transformed_tensors = []
213
+ pil_images = []
214
+ widths, heights = [], []
215
+
216
+ # Применяем трансформацию ко всем изображениям
217
+ for img in images:
218
+ try:
219
+ t_img, pil_img, w, h = transform(img)
220
+ transformed_tensors.append(t_img)
221
+ pil_images.append(pil_img)
222
+ widths.append(w)
223
+ heights.append(h)
224
+ except Exception as e:
225
+ print(f"Ошибка трансформации: {e}")
226
+ continue
227
+
228
+ if not transformed_tensors:
229
+ return None
230
+
231
+ # Создаём батч
232
+ batch_tensor = torch.stack(transformed_tensors).to(device, dtype)
233
+ if batch_tensor.ndim==5:
234
+ batch_tensor = batch_tensor.unsqueeze(2) # [B, C, 1, H, W]
235
+
236
+ # Кодируем батч
237
+ with torch.no_grad():
238
+ posteriors = vae.encode(batch_tensor).latent_dist.mode()
239
+
240
+ latents = (posteriors - shift_factor) / scaling_factor
241
+
242
+ if latents_mean!=None and latents_std!=None:
243
+ latents = (latents - torch.tensor(latents_mean, device=device, dtype=dtype).view(1, -1, 1, 1, 1)) / torch.tensor(latents_std, device=device, dtype=dtype).view(1, -1, 1, 1, 1)
244
+ #print(latents.ndim, latents.shape)
245
+ if latents.ndim==5:
246
+ latents = latents[:, :, 0, :, :] # Убираем временную ось [B, C, H, W]
247
+
248
+ latents_np = latents.to(dtype).cpu().numpy()
249
+
250
+ # Обрабатываем тексты
251
+ text_labels = [clean_label(text) for text in texts]
252
+ if random.random() < img_share:
253
+ embeddings = encode_images_batch(pil_images, processor, model)
254
+ text_labels = [f"img: {label}" for label in text_labels]
255
+ else:
256
+ model_prompts, text_labels = process_labels_for_guidance(text_labels, empty_share)
257
+ if textemb_full:
258
+ embeddings = encode_texts_batch_full(model_prompts, tokenizer, model)
259
+ else:
260
+ embeddings = encode_texts_batch(model_prompts, tokenizer, model)
261
+
262
+ return {
263
+ "vae": latents_np,
264
+ "embeddings": embeddings,
265
+ "text": text_labels,
266
+ "width": widths,
267
+ "height": heights
268
+ }
269
+
270
+ except Exception as e:
271
+ print(f"Критическая ошибка в encode_to_latents: {e}")
272
+ raise
273
+
274
+
275
+ # ---------------- 5️⃣ Обработка папки с изображениями и текстами ----------------
276
+ def process_folder(folder_path, limit=None):
277
+ """
278
+ Рекурсивно обходит указанную директорию и все вложенные директории,
279
+ собирая пути к изображениям и соответствующим текстовым файлам.
280
+ """
281
+ image_paths = []
282
+ text_paths = []
283
+ width = []
284
+ height = []
285
+ transform = get_image_transform(min_size, max_size, step)
286
+
287
+ # Используем os.walk для рекурсивного обхода директорий
288
+ for root, dirs, files in os.walk(folder_path):
289
+ for filename in files:
290
+ # Проверяем, является ли файл изображением
291
+ if filename.lower().endswith((".jpg", ".jpeg", ".png")):
292
+ image_path = os.path.join(root, filename)
293
+ try:
294
+ img = Image.open(image_path)
295
+ except Exception as e:
296
+ print(f"Ошибка при открытии {image_path}: {e}")
297
+ os.remove(image_path)
298
+ text_path = os.path.splitext(image_path)[0] + ".txt"
299
+ if os.path.exists(text_path):
300
+ os.remove(text_path)
301
+ continue
302
+ # Применяем трансформацию только для получения размеров
303
+ w, h = transform(img, dry_run=True)
304
+ # Формируем путь к текстовому файлу
305
+ text_path = os.path.splitext(image_path)[0] + ".txt"
306
+
307
+ # Добавляем пути, если текстовый файл существует
308
+ if os.path.exists(text_path) and min(w, h)>0:
309
+ image_paths.append(image_path)
310
+ text_paths.append(text_path)
311
+ width.append(w) # Добавляем в список
312
+ height.append(h) # Добавляем в список
313
+
314
+ # Проверяем ограничение на количество
315
+ if limit and limit>0 and len(image_paths) >= limit:
316
+ print(f"Достигнут лимит в {limit} изображений")
317
+ return image_paths, text_paths, width, height
318
+
319
+ print(f"Найдено {len(image_paths)} изображений с текстовыми описаниями")
320
+ return image_paths, text_paths, width, height
321
+
322
+ def process_in_chunks(image_paths, text_paths, width, height, chunk_size=50000, batch_size=1):
323
+ total_files = len(image_paths)
324
+ start_time = time.time()
325
+ chunks = range(0, total_files, chunk_size)
326
+
327
+ for chunk_idx, start in enumerate(chunks, 1):
328
+ end = min(start + chunk_size, total_files)
329
+ chunk_image_paths = image_paths[start:end]
330
+ chunk_text_paths = text_paths[start:end]
331
+ chunk_widths = width[start:end] if isinstance(width, list) else [width] * len(chunk_image_paths)
332
+ chunk_heights = height[start:end] if isinstance(height, list) else [height] * len(chunk_image_paths)
333
+
334
+ # Чтение текстов
335
+ chunk_texts = []
336
+ for text_path in chunk_text_paths:
337
+ try:
338
+ with open(text_path, 'r', encoding='utf-8') as f:
339
+ text = f.read().strip()
340
+ chunk_texts.append(text)
341
+ except Exception as e:
342
+ print(f"Ошибка чтения {text_path}: {e}")
343
+ chunk_texts.append("")
344
+
345
+ # Группируем изображения по размерам
346
+ size_groups = {}
347
+ for i in range(len(chunk_image_paths)):
348
+ size_key = (chunk_widths[i], chunk_heights[i])
349
+ if size_key not in size_groups:
350
+ size_groups[size_key] = {"image_paths": [], "texts": []}
351
+ size_groups[size_key]["image_paths"].append(chunk_image_paths[i])
352
+ size_groups[size_key]["texts"].append(chunk_texts[i])
353
+
354
+ # Обрабатываем каждую группу размеров отдельно
355
+ for size_key, group_data in size_groups.items():
356
+ print(f"Обработка группы с размером {size_key[0]}x{size_key[1]} - {len(group_data['image_paths'])} изображений")
357
+
358
+ group_dataset = Dataset.from_dict({
359
+ "image_path": group_data["image_paths"],
360
+ "text": group_data["texts"]
361
+ })
362
+
363
+ # Теперь можно использовать указанный batch_size, т.к. все изображения одного размера
364
+ processed_group = group_dataset.map(
365
+ lambda examples: encode_to_latents(
366
+ [Image.open(path) for path in examples["image_path"]],
367
+ examples["text"]
368
+ ),
369
+ batched=True,
370
+ batch_size=batch_size,
371
+ #remove_columns=["image_path"],
372
+ desc=f"Обработка группы размера {size_key[0]}x{size_key[1]}"
373
+ )
374
+
375
+ # Сохраняем результаты группы
376
+ group_save_path = f"{save_path}_temp/chunk_{chunk_idx}_size_{size_key[0]}x{size_key[1]}"
377
+ processed_group.save_to_disk(group_save_path)
378
+ clear_cuda_memory()
379
+ elapsed = time.time() - start_time
380
+ processed = (chunk_idx - 1) * chunk_size + sum([len(sg["image_paths"]) for sg in list(size_groups.values())[:list(size_groups.values()).index(group_data) + 1]])
381
+ if processed > 0:
382
+ remaining = (elapsed / processed) * (total_files - processed)
383
+ elapsed_str = str(timedelta(seconds=int(elapsed)))
384
+ remaining_str = str(timedelta(seconds=int(remaining)))
385
+ print(f"ETA: Прошло {elapsed_str}, Осталось {remaining_str}, Прогресс {processed}/{total_files} ({processed/total_files:.1%})")
386
+
387
+ # ---------------- 7️⃣ Объединение чанков ----------------
388
+ def combine_chunks(temp_path, final_path):
389
+ """Объединение обработанных чанков в финальный датасет"""
390
+ chunks = sorted([
391
+ os.path.join(temp_path, d)
392
+ for d in os.listdir(temp_path)
393
+ if d.startswith("chunk_")
394
+ ])
395
+
396
+ datasets = [load_from_disk(chunk) for chunk in chunks]
397
+ combined = concatenate_datasets(datasets)
398
+ combined.save_to_disk(final_path)
399
+
400
+ print(f"✅ Датасет успешно сохранен в: {final_path}")
401
+
402
+
403
+
404
+ # Создаем временную папку для чанков
405
+ temp_path = f"{save_path}_temp"
406
+ os.makedirs(temp_path, exist_ok=True)
407
+
408
+ # Получаем список файлов
409
+ image_paths, text_paths, width, height = process_folder(folder_path,limit)
410
+ print(f"Всего найдено {len(image_paths)} изображений")
411
+
412
+ # Обработка с чанкованием
413
+ process_in_chunks(image_paths, text_paths, width, height, chunk_size=100000, batch_size=batch_size)
414
+
415
+ # Объединение чанков в финальный датасет
416
+ combine_chunks(temp_path, save_path)
417
+
418
+ # Удаление временной папки
419
+ try:
420
+ shutil.rmtree(temp_path)
421
+ print(f"✅ Временная папка {temp_path} успешно удалена")
422
+ except Exception as e:
423
+ print(f"⚠️ Ошибка при удалении временной папки: {e}")
src/dataset_sample.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/model_create.ipynb ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "5212f806-14b4-4b5f-bcb4-09e36df3b7d9",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "test unet\n",
14
+ "Количество параметров: 1616742724\n",
15
+ "Output shape: torch.Size([1, 4, 60, 48])\n",
16
+ "UNet2DConditionModel(\n",
17
+ " (conv_in): Conv2d(4, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
18
+ " (time_proj): Timesteps()\n",
19
+ " (time_embedding): TimestepEmbedding(\n",
20
+ " (linear_1): Linear(in_features=288, out_features=1152, bias=True)\n",
21
+ " (act): SiLU()\n",
22
+ " (linear_2): Linear(in_features=1152, out_features=1152, bias=True)\n",
23
+ " )\n",
24
+ " (down_blocks): ModuleList(\n",
25
+ " (0): DownBlock2D(\n",
26
+ " (resnets): ModuleList(\n",
27
+ " (0-1): 2 x ResnetBlock2D(\n",
28
+ " (norm1): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
29
+ " (conv1): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
30
+ " (time_emb_proj): Linear(in_features=1152, out_features=288, bias=True)\n",
31
+ " (norm2): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
32
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
33
+ " (conv2): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
34
+ " (nonlinearity): SiLU()\n",
35
+ " )\n",
36
+ " )\n",
37
+ " (downsamplers): ModuleList(\n",
38
+ " (0): Downsample2D(\n",
39
+ " (conv): Conv2d(288, 288, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
40
+ " )\n",
41
+ " )\n",
42
+ " )\n",
43
+ " (1): CrossAttnDownBlock2D(\n",
44
+ " (attentions): ModuleList(\n",
45
+ " (0-1): 2 x Transformer2DModel(\n",
46
+ " (norm): GroupNorm(32, 576, eps=1e-06, affine=True)\n",
47
+ " (proj_in): Linear(in_features=576, out_features=576, bias=True)\n",
48
+ " (transformer_blocks): ModuleList(\n",
49
+ " (0): BasicTransformerBlock(\n",
50
+ " (norm1): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
51
+ " (attn1): Attention(\n",
52
+ " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
53
+ " (to_k): Linear(in_features=576, out_features=576, bias=False)\n",
54
+ " (to_v): Linear(in_features=576, out_features=576, bias=False)\n",
55
+ " (to_out): ModuleList(\n",
56
+ " (0): Linear(in_features=576, out_features=576, bias=True)\n",
57
+ " (1): Dropout(p=0.0, inplace=False)\n",
58
+ " )\n",
59
+ " )\n",
60
+ " (norm2): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
61
+ " (attn2): Attention(\n",
62
+ " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
63
+ " (to_k): Linear(in_features=1152, out_features=576, bias=False)\n",
64
+ " (to_v): Linear(in_features=1152, out_features=576, bias=False)\n",
65
+ " (to_out): ModuleList(\n",
66
+ " (0): Linear(in_features=576, out_features=576, bias=True)\n",
67
+ " (1): Dropout(p=0.0, inplace=False)\n",
68
+ " )\n",
69
+ " )\n",
70
+ " (norm3): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
71
+ " (ff): FeedForward(\n",
72
+ " (net): ModuleList(\n",
73
+ " (0): GEGLU(\n",
74
+ " (proj): Linear(in_features=576, out_features=4608, bias=True)\n",
75
+ " )\n",
76
+ " (1): Dropout(p=0.0, inplace=False)\n",
77
+ " (2): Linear(in_features=2304, out_features=576, bias=True)\n",
78
+ " )\n",
79
+ " )\n",
80
+ " )\n",
81
+ " )\n",
82
+ " (proj_out): Linear(in_features=576, out_features=576, bias=True)\n",
83
+ " )\n",
84
+ " )\n",
85
+ " (resnets): ModuleList(\n",
86
+ " (0): ResnetBlock2D(\n",
87
+ " (norm1): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
88
+ " (conv1): Conv2d(288, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
89
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
90
+ " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
91
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
92
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
93
+ " (nonlinearity): SiLU()\n",
94
+ " (conv_shortcut): Conv2d(288, 576, kernel_size=(1, 1), stride=(1, 1))\n",
95
+ " )\n",
96
+ " (1): ResnetBlock2D(\n",
97
+ " (norm1): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
98
+ " (conv1): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
99
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
100
+ " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
101
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
102
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
103
+ " (nonlinearity): SiLU()\n",
104
+ " )\n",
105
+ " )\n",
106
+ " (downsamplers): ModuleList(\n",
107
+ " (0): Downsample2D(\n",
108
+ " (conv): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
109
+ " )\n",
110
+ " )\n",
111
+ " )\n",
112
+ " (2): CrossAttnDownBlock2D(\n",
113
+ " (attentions): ModuleList(\n",
114
+ " (0-1): 2 x Transformer2DModel(\n",
115
+ " (norm): GroupNorm(32, 1152, eps=1e-06, affine=True)\n",
116
+ " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
117
+ " (transformer_blocks): ModuleList(\n",
118
+ " (0-7): 8 x BasicTransformerBlock(\n",
119
+ " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
120
+ " (attn1): Attention(\n",
121
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
122
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
123
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
124
+ " (to_out): ModuleList(\n",
125
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
126
+ " (1): Dropout(p=0.0, inplace=False)\n",
127
+ " )\n",
128
+ " )\n",
129
+ " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
130
+ " (attn2): Attention(\n",
131
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
132
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
133
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
134
+ " (to_out): ModuleList(\n",
135
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
136
+ " (1): Dropout(p=0.0, inplace=False)\n",
137
+ " )\n",
138
+ " )\n",
139
+ " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
140
+ " (ff): FeedForward(\n",
141
+ " (net): ModuleList(\n",
142
+ " (0): GEGLU(\n",
143
+ " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
144
+ " )\n",
145
+ " (1): Dropout(p=0.0, inplace=False)\n",
146
+ " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
147
+ " )\n",
148
+ " )\n",
149
+ " )\n",
150
+ " )\n",
151
+ " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
152
+ " )\n",
153
+ " )\n",
154
+ " (resnets): ModuleList(\n",
155
+ " (0): ResnetBlock2D(\n",
156
+ " (norm1): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
157
+ " (conv1): Conv2d(576, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
158
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
159
+ " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
160
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
161
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
162
+ " (nonlinearity): SiLU()\n",
163
+ " (conv_shortcut): Conv2d(576, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
164
+ " )\n",
165
+ " (1): ResnetBlock2D(\n",
166
+ " (norm1): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
167
+ " (conv1): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
168
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
169
+ " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
170
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
171
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
172
+ " (nonlinearity): SiLU()\n",
173
+ " )\n",
174
+ " )\n",
175
+ " )\n",
176
+ " )\n",
177
+ " (up_blocks): ModuleList(\n",
178
+ " (0): CrossAttnUpBlock2D(\n",
179
+ " (attentions): ModuleList(\n",
180
+ " (0-2): 3 x Transformer2DModel(\n",
181
+ " (norm): GroupNorm(32, 1152, eps=1e-06, affine=True)\n",
182
+ " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
183
+ " (transformer_blocks): ModuleList(\n",
184
+ " (0-7): 8 x BasicTransformerBlock(\n",
185
+ " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
186
+ " (attn1): Attention(\n",
187
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
188
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
189
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
190
+ " (to_out): ModuleList(\n",
191
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
192
+ " (1): Dropout(p=0.0, inplace=False)\n",
193
+ " )\n",
194
+ " )\n",
195
+ " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
196
+ " (attn2): Attention(\n",
197
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
198
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
199
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
200
+ " (to_out): ModuleList(\n",
201
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
202
+ " (1): Dropout(p=0.0, inplace=False)\n",
203
+ " )\n",
204
+ " )\n",
205
+ " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
206
+ " (ff): FeedForward(\n",
207
+ " (net): ModuleList(\n",
208
+ " (0): GEGLU(\n",
209
+ " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
210
+ " )\n",
211
+ " (1): Dropout(p=0.0, inplace=False)\n",
212
+ " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
213
+ " )\n",
214
+ " )\n",
215
+ " )\n",
216
+ " )\n",
217
+ " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
218
+ " )\n",
219
+ " )\n",
220
+ " (resnets): ModuleList(\n",
221
+ " (0-1): 2 x ResnetBlock2D(\n",
222
+ " (norm1): GroupNorm(32, 2304, eps=1e-05, affine=True)\n",
223
+ " (conv1): Conv2d(2304, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
224
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
225
+ " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
226
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
227
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
228
+ " (nonlinearity): SiLU()\n",
229
+ " (conv_shortcut): Conv2d(2304, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
230
+ " )\n",
231
+ " (2): ResnetBlock2D(\n",
232
+ " (norm1): GroupNorm(32, 1728, eps=1e-05, affine=True)\n",
233
+ " (conv1): Conv2d(1728, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
234
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
235
+ " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
236
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
237
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
238
+ " (nonlinearity): SiLU()\n",
239
+ " (conv_shortcut): Conv2d(1728, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
240
+ " )\n",
241
+ " )\n",
242
+ " (upsamplers): ModuleList(\n",
243
+ " (0): Upsample2D(\n",
244
+ " (conv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
245
+ " )\n",
246
+ " )\n",
247
+ " )\n",
248
+ " (1): CrossAttnUpBlock2D(\n",
249
+ " (attentions): ModuleList(\n",
250
+ " (0-2): 3 x Transformer2DModel(\n",
251
+ " (norm): GroupNorm(32, 576, eps=1e-06, affine=True)\n",
252
+ " (proj_in): Linear(in_features=576, out_features=576, bias=True)\n",
253
+ " (transformer_blocks): ModuleList(\n",
254
+ " (0): BasicTransformerBlock(\n",
255
+ " (norm1): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
256
+ " (attn1): Attention(\n",
257
+ " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
258
+ " (to_k): Linear(in_features=576, out_features=576, bias=False)\n",
259
+ " (to_v): Linear(in_features=576, out_features=576, bias=False)\n",
260
+ " (to_out): ModuleList(\n",
261
+ " (0): Linear(in_features=576, out_features=576, bias=True)\n",
262
+ " (1): Dropout(p=0.0, inplace=False)\n",
263
+ " )\n",
264
+ " )\n",
265
+ " (norm2): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
266
+ " (attn2): Attention(\n",
267
+ " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
268
+ " (to_k): Linear(in_features=1152, out_features=576, bias=False)\n",
269
+ " (to_v): Linear(in_features=1152, out_features=576, bias=False)\n",
270
+ " (to_out): ModuleList(\n",
271
+ " (0): Linear(in_features=576, out_features=576, bias=True)\n",
272
+ " (1): Dropout(p=0.0, inplace=False)\n",
273
+ " )\n",
274
+ " )\n",
275
+ " (norm3): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
276
+ " (ff): FeedForward(\n",
277
+ " (net): ModuleList(\n",
278
+ " (0): GEGLU(\n",
279
+ " (proj): Linear(in_features=576, out_features=4608, bias=True)\n",
280
+ " )\n",
281
+ " (1): Dropout(p=0.0, inplace=False)\n",
282
+ " (2): Linear(in_features=2304, out_features=576, bias=True)\n",
283
+ " )\n",
284
+ " )\n",
285
+ " )\n",
286
+ " )\n",
287
+ " (proj_out): Linear(in_features=576, out_features=576, bias=True)\n",
288
+ " )\n",
289
+ " )\n",
290
+ " (resnets): ModuleList(\n",
291
+ " (0): ResnetBlock2D(\n",
292
+ " (norm1): GroupNorm(32, 1728, eps=1e-05, affine=True)\n",
293
+ " (conv1): Conv2d(1728, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
294
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
295
+ " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
296
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
297
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
298
+ " (nonlinearity): SiLU()\n",
299
+ " (conv_shortcut): Conv2d(1728, 576, kernel_size=(1, 1), stride=(1, 1))\n",
300
+ " )\n",
301
+ " (1): ResnetBlock2D(\n",
302
+ " (norm1): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
303
+ " (conv1): Conv2d(1152, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
304
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
305
+ " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
306
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
307
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
308
+ " (nonlinearity): SiLU()\n",
309
+ " (conv_shortcut): Conv2d(1152, 576, kernel_size=(1, 1), stride=(1, 1))\n",
310
+ " )\n",
311
+ " (2): ResnetBlock2D(\n",
312
+ " (norm1): GroupNorm(32, 864, eps=1e-05, affine=True)\n",
313
+ " (conv1): Conv2d(864, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
314
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
315
+ " (norm2): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
316
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
317
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
318
+ " (nonlinearity): SiLU()\n",
319
+ " (conv_shortcut): Conv2d(864, 576, kernel_size=(1, 1), stride=(1, 1))\n",
320
+ " )\n",
321
+ " )\n",
322
+ " (upsamplers): ModuleList(\n",
323
+ " (0): Upsample2D(\n",
324
+ " (conv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
325
+ " )\n",
326
+ " )\n",
327
+ " )\n",
328
+ " (2): UpBlock2D(\n",
329
+ " (resnets): ModuleList(\n",
330
+ " (0): ResnetBlock2D(\n",
331
+ " (norm1): GroupNorm(32, 864, eps=1e-05, affine=True)\n",
332
+ " (conv1): Conv2d(864, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
333
+ " (time_emb_proj): Linear(in_features=1152, out_features=288, bias=True)\n",
334
+ " (norm2): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
335
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
336
+ " (conv2): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
337
+ " (nonlinearity): SiLU()\n",
338
+ " (conv_shortcut): Conv2d(864, 288, kernel_size=(1, 1), stride=(1, 1))\n",
339
+ " )\n",
340
+ " (1-2): 2 x ResnetBlock2D(\n",
341
+ " (norm1): GroupNorm(32, 576, eps=1e-05, affine=True)\n",
342
+ " (conv1): Conv2d(576, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
343
+ " (time_emb_proj): Linear(in_features=1152, out_features=288, bias=True)\n",
344
+ " (norm2): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
345
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
346
+ " (conv2): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
347
+ " (nonlinearity): SiLU()\n",
348
+ " (conv_shortcut): Conv2d(576, 288, kernel_size=(1, 1), stride=(1, 1))\n",
349
+ " )\n",
350
+ " )\n",
351
+ " )\n",
352
+ " )\n",
353
+ " (mid_block): UNetMidBlock2DCrossAttn(\n",
354
+ " (attentions): ModuleList(\n",
355
+ " (0): Transformer2DModel(\n",
356
+ " (norm): GroupNorm(32, 1152, eps=1e-06, affine=True)\n",
357
+ " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
358
+ " (transformer_blocks): ModuleList(\n",
359
+ " (0-7): 8 x BasicTransformerBlock(\n",
360
+ " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
361
+ " (attn1): Attention(\n",
362
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
363
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
364
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
365
+ " (to_out): ModuleList(\n",
366
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
367
+ " (1): Dropout(p=0.0, inplace=False)\n",
368
+ " )\n",
369
+ " )\n",
370
+ " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
371
+ " (attn2): Attention(\n",
372
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
373
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
374
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
375
+ " (to_out): ModuleList(\n",
376
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
377
+ " (1): Dropout(p=0.0, inplace=False)\n",
378
+ " )\n",
379
+ " )\n",
380
+ " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
381
+ " (ff): FeedForward(\n",
382
+ " (net): ModuleList(\n",
383
+ " (0): GEGLU(\n",
384
+ " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
385
+ " )\n",
386
+ " (1): Dropout(p=0.0, inplace=False)\n",
387
+ " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
388
+ " )\n",
389
+ " )\n",
390
+ " )\n",
391
+ " )\n",
392
+ " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
393
+ " )\n",
394
+ " )\n",
395
+ " (resnets): ModuleList(\n",
396
+ " (0-1): 2 x ResnetBlock2D(\n",
397
+ " (norm1): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
398
+ " (conv1): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
399
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
400
+ " (norm2): GroupNorm(32, 1152, eps=1e-05, affine=True)\n",
401
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
402
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
403
+ " (nonlinearity): SiLU()\n",
404
+ " )\n",
405
+ " )\n",
406
+ " )\n",
407
+ " (conv_norm_out): GroupNorm(32, 288, eps=1e-05, affine=True)\n",
408
+ " (conv_act): SiLU()\n",
409
+ " (conv_out): Conv2d(288, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
410
+ ")\n"
411
+ ]
412
+ }
413
+ ],
414
+ "source": [
415
+ "config_sdxs = {\n",
416
+ " # === Основные размеры и каналы ===\n",
417
+ " \"in_channels\": 4, # Количество входных каналов (совместимость с VAE)\n",
418
+ " \"out_channels\": 4, # Количество выходных каналов (симметрично in_channels) \n",
419
+ "\n",
420
+ " # === Cross-Attention ===\n",
421
+ " \"cross_attention_dim\": 1152, # Размерность текстовых эмбеддингов\n",
422
+ " \"use_linear_projection\": True,\n",
423
+ " \"norm_num_groups\": 32,\n",
424
+ " \n",
425
+ " # === Архитектура блоков ===\n",
426
+ " \"down_block_types\": [ # энкодер\n",
427
+ " \"DownBlock2D\",\n",
428
+ " \"CrossAttnDownBlock2D\",\n",
429
+ " \"CrossAttnDownBlock2D\",\n",
430
+ " #\"CrossAttnDownBlock2D\",\n",
431
+ " ],\n",
432
+ " \"up_block_types\": [ # декодер\n",
433
+ " #\"CrossAttnUpBlock2D\",\n",
434
+ " \"CrossAttnUpBlock2D\",\n",
435
+ " \"CrossAttnUpBlock2D\",\n",
436
+ " \"UpBlock2D\",\n",
437
+ " ],\n",
438
+ "\n",
439
+ " # === Конфигурация каналов ===\n",
440
+ " \"block_out_channels\": [288, 576, 1152],\n",
441
+ "\n",
442
+ " \"transformer_layers_per_block\": [1, 1, 8],\n",
443
+ " \"attention_head_dim\": [6, 9, 18],\n",
444
+ "}\n",
445
+ "\n",
446
+ "def check_initialization(model):\n",
447
+ " for name, param in model.named_parameters():\n",
448
+ " if param.requires_grad:\n",
449
+ " print(f\"{name}: mean={param.data.mean():.3f}, std={param.data.std():.3f}\")\n",
450
+ "\n",
451
+ "\n",
452
+ "if 1:\n",
453
+ " checkpoint_path = \"/workspace/sdxs3d/unet\"#\"sdxs\"\n",
454
+ " import torch\n",
455
+ " from diffusers import UNet2DConditionModel\n",
456
+ " print(\"test unet\")\n",
457
+ " new_unet = UNet2DConditionModel(**config_sdxs).to(\"cuda\", dtype=torch.float16)\n",
458
+ " #new_unet = UNet2DConditionModel().to(\"cuda\", dtype=torch.float16)\n",
459
+ "\n",
460
+ " # После инициализации\n",
461
+ " #check_initialization(new_unet)\n",
462
+ "\n",
463
+ " #assert all(ch % 32 == 0 for ch in new_unet.config[\"block_out_channels\"]), \"Каналы должны быть кратны 32\"\n",
464
+ " num_params = sum(p.numel() for p in new_unet.parameters())\n",
465
+ " print(f\"Количество параметров: {num_params}\")\n",
466
+ "\n",
467
+ " # Генерация тестового латента (640x512 в latent space)\n",
468
+ " test_latent = torch.randn(1,4, 60, 48).to(\"cuda\", dtype=torch.float16) # 60x48 ≈ 512px\n",
469
+ " timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n",
470
+ " encoder_hidden_states = torch.randn(1, 77, 1152).to(\"cuda\", dtype=torch.float16)\n",
471
+ " \n",
472
+ " with torch.no_grad():\n",
473
+ " output = new_unet(\n",
474
+ " test_latent, \n",
475
+ " timesteps, \n",
476
+ " encoder_hidden_states\n",
477
+ " ).sample\n",
478
+ "\n",
479
+ " print(f\"Output shape: {output.shape}\")\n",
480
+ " new_unet.save_pretrained(checkpoint_path)\n",
481
+ " print(new_unet) "
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": null,
487
+ "id": "1cb4ff0f-36cc-43cf-86a4-aaab9f106725",
488
+ "metadata": {},
489
+ "outputs": [],
490
+ "source": []
491
+ }
492
+ ],
493
+ "metadata": {
494
+ "kernelspec": {
495
+ "display_name": "Python 3 (ipykernel)",
496
+ "language": "python",
497
+ "name": "python3"
498
+ },
499
+ "language_info": {
500
+ "codemirror_mode": {
501
+ "name": "ipython",
502
+ "version": 3
503
+ },
504
+ "file_extension": ".py",
505
+ "mimetype": "text/x-python",
506
+ "name": "python",
507
+ "nbconvert_exporter": "python",
508
+ "pygments_lexer": "ipython3",
509
+ "version": "3.11.10"
510
+ }
511
+ },
512
+ "nbformat": 4,
513
+ "nbformat_minor": 5
514
+ }
src/model_create48.ipynb ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "5212f806-14b4-4b5f-bcb4-09e36df3b7d9",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "test unet\n",
14
+ "Количество параметров: 1956883440\n",
15
+ "Output shape: torch.Size([1, 48, 60, 48])\n",
16
+ "UNet2DConditionModel(\n",
17
+ " (conv_in): Conv2d(48, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
18
+ " (time_proj): Timesteps()\n",
19
+ " (time_embedding): TimestepEmbedding(\n",
20
+ " (linear_1): Linear(in_features=288, out_features=1152, bias=True)\n",
21
+ " (act): SiLU()\n",
22
+ " (linear_2): Linear(in_features=1152, out_features=1152, bias=True)\n",
23
+ " )\n",
24
+ " (down_blocks): ModuleList(\n",
25
+ " (0): DownBlock2D(\n",
26
+ " (resnets): ModuleList(\n",
27
+ " (0-1): 2 x ResnetBlock2D(\n",
28
+ " (norm1): GroupNorm(48, 288, eps=1e-05, affine=True)\n",
29
+ " (conv1): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
30
+ " (time_emb_proj): Linear(in_features=1152, out_features=288, bias=True)\n",
31
+ " (norm2): GroupNorm(48, 288, eps=1e-05, affine=True)\n",
32
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
33
+ " (conv2): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
34
+ " (nonlinearity): SiLU()\n",
35
+ " )\n",
36
+ " )\n",
37
+ " (downsamplers): ModuleList(\n",
38
+ " (0): Downsample2D(\n",
39
+ " (conv): Conv2d(288, 288, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
40
+ " )\n",
41
+ " )\n",
42
+ " )\n",
43
+ " (1): CrossAttnDownBlock2D(\n",
44
+ " (attentions): ModuleList(\n",
45
+ " (0-1): 2 x Transformer2DModel(\n",
46
+ " (norm): GroupNorm(48, 576, eps=1e-06, affine=True)\n",
47
+ " (proj_in): Linear(in_features=576, out_features=576, bias=True)\n",
48
+ " (transformer_blocks): ModuleList(\n",
49
+ " (0): BasicTransformerBlock(\n",
50
+ " (norm1): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
51
+ " (attn1): Attention(\n",
52
+ " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
53
+ " (to_k): Linear(in_features=576, out_features=576, bias=False)\n",
54
+ " (to_v): Linear(in_features=576, out_features=576, bias=False)\n",
55
+ " (to_out): ModuleList(\n",
56
+ " (0): Linear(in_features=576, out_features=576, bias=True)\n",
57
+ " (1): Dropout(p=0.0, inplace=False)\n",
58
+ " )\n",
59
+ " )\n",
60
+ " (norm2): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
61
+ " (attn2): Attention(\n",
62
+ " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
63
+ " (to_k): Linear(in_features=1152, out_features=576, bias=False)\n",
64
+ " (to_v): Linear(in_features=1152, out_features=576, bias=False)\n",
65
+ " (to_out): ModuleList(\n",
66
+ " (0): Linear(in_features=576, out_features=576, bias=True)\n",
67
+ " (1): Dropout(p=0.0, inplace=False)\n",
68
+ " )\n",
69
+ " )\n",
70
+ " (norm3): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
71
+ " (ff): FeedForward(\n",
72
+ " (net): ModuleList(\n",
73
+ " (0): GEGLU(\n",
74
+ " (proj): Linear(in_features=576, out_features=4608, bias=True)\n",
75
+ " )\n",
76
+ " (1): Dropout(p=0.0, inplace=False)\n",
77
+ " (2): Linear(in_features=2304, out_features=576, bias=True)\n",
78
+ " )\n",
79
+ " )\n",
80
+ " )\n",
81
+ " )\n",
82
+ " (proj_out): Linear(in_features=576, out_features=576, bias=True)\n",
83
+ " )\n",
84
+ " )\n",
85
+ " (resnets): ModuleList(\n",
86
+ " (0): ResnetBlock2D(\n",
87
+ " (norm1): GroupNorm(48, 288, eps=1e-05, affine=True)\n",
88
+ " (conv1): Conv2d(288, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
89
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
90
+ " (norm2): GroupNorm(48, 576, eps=1e-05, affine=True)\n",
91
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
92
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
93
+ " (nonlinearity): SiLU()\n",
94
+ " (conv_shortcut): Conv2d(288, 576, kernel_size=(1, 1), stride=(1, 1))\n",
95
+ " )\n",
96
+ " (1): ResnetBlock2D(\n",
97
+ " (norm1): GroupNorm(48, 576, eps=1e-05, affine=True)\n",
98
+ " (conv1): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
99
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
100
+ " (norm2): GroupNorm(48, 576, eps=1e-05, affine=True)\n",
101
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
102
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
103
+ " (nonlinearity): SiLU()\n",
104
+ " )\n",
105
+ " )\n",
106
+ " (downsamplers): ModuleList(\n",
107
+ " (0): Downsample2D(\n",
108
+ " (conv): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
109
+ " )\n",
110
+ " )\n",
111
+ " )\n",
112
+ " (2): CrossAttnDownBlock2D(\n",
113
+ " (attentions): ModuleList(\n",
114
+ " (0-1): 2 x Transformer2DModel(\n",
115
+ " (norm): GroupNorm(48, 1152, eps=1e-06, affine=True)\n",
116
+ " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
117
+ " (transformer_blocks): ModuleList(\n",
118
+ " (0): BasicTransformerBlock(\n",
119
+ " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
120
+ " (attn1): Attention(\n",
121
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
122
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
123
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
124
+ " (to_out): ModuleList(\n",
125
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
126
+ " (1): Dropout(p=0.0, inplace=False)\n",
127
+ " )\n",
128
+ " )\n",
129
+ " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
130
+ " (attn2): Attention(\n",
131
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
132
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
133
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
134
+ " (to_out): ModuleList(\n",
135
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
136
+ " (1): Dropout(p=0.0, inplace=False)\n",
137
+ " )\n",
138
+ " )\n",
139
+ " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
140
+ " (ff): FeedForward(\n",
141
+ " (net): ModuleList(\n",
142
+ " (0): GEGLU(\n",
143
+ " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
144
+ " )\n",
145
+ " (1): Dropout(p=0.0, inplace=False)\n",
146
+ " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
147
+ " )\n",
148
+ " )\n",
149
+ " )\n",
150
+ " )\n",
151
+ " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
152
+ " )\n",
153
+ " )\n",
154
+ " (resnets): ModuleList(\n",
155
+ " (0): ResnetBlock2D(\n",
156
+ " (norm1): GroupNorm(48, 576, eps=1e-05, affine=True)\n",
157
+ " (conv1): Conv2d(576, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
158
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
159
+ " (norm2): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
160
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
161
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
162
+ " (nonlinearity): SiLU()\n",
163
+ " (conv_shortcut): Conv2d(576, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
164
+ " )\n",
165
+ " (1): ResnetBlock2D(\n",
166
+ " (norm1): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
167
+ " (conv1): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
168
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
169
+ " (norm2): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
170
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
171
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
172
+ " (nonlinearity): SiLU()\n",
173
+ " )\n",
174
+ " )\n",
175
+ " (downsamplers): ModuleList(\n",
176
+ " (0): Downsample2D(\n",
177
+ " (conv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
178
+ " )\n",
179
+ " )\n",
180
+ " )\n",
181
+ " (3): CrossAttnDownBlock2D(\n",
182
+ " (attentions): ModuleList(\n",
183
+ " (0-1): 2 x Transformer2DModel(\n",
184
+ " (norm): GroupNorm(48, 1152, eps=1e-06, affine=True)\n",
185
+ " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
186
+ " (transformer_blocks): ModuleList(\n",
187
+ " (0-7): 8 x BasicTransformerBlock(\n",
188
+ " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
189
+ " (attn1): Attention(\n",
190
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
191
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
192
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
193
+ " (to_out): ModuleList(\n",
194
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
195
+ " (1): Dropout(p=0.0, inplace=False)\n",
196
+ " )\n",
197
+ " )\n",
198
+ " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
199
+ " (attn2): Attention(\n",
200
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
201
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
202
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
203
+ " (to_out): ModuleList(\n",
204
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
205
+ " (1): Dropout(p=0.0, inplace=False)\n",
206
+ " )\n",
207
+ " )\n",
208
+ " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
209
+ " (ff): FeedForward(\n",
210
+ " (net): ModuleList(\n",
211
+ " (0): GEGLU(\n",
212
+ " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
213
+ " )\n",
214
+ " (1): Dropout(p=0.0, inplace=False)\n",
215
+ " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
216
+ " )\n",
217
+ " )\n",
218
+ " )\n",
219
+ " )\n",
220
+ " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
221
+ " )\n",
222
+ " )\n",
223
+ " (resnets): ModuleList(\n",
224
+ " (0-1): 2 x ResnetBlock2D(\n",
225
+ " (norm1): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
226
+ " (conv1): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
227
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
228
+ " (norm2): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
229
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
230
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
231
+ " (nonlinearity): SiLU()\n",
232
+ " )\n",
233
+ " )\n",
234
+ " )\n",
235
+ " )\n",
236
+ " (up_blocks): ModuleList(\n",
237
+ " (0): CrossAttnUpBlock2D(\n",
238
+ " (attentions): ModuleList(\n",
239
+ " (0-2): 3 x Transformer2DModel(\n",
240
+ " (norm): GroupNorm(48, 1152, eps=1e-06, affine=True)\n",
241
+ " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
242
+ " (transformer_blocks): ModuleList(\n",
243
+ " (0-7): 8 x BasicTransformerBlock(\n",
244
+ " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
245
+ " (attn1): Attention(\n",
246
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
247
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
248
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
249
+ " (to_out): ModuleList(\n",
250
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
251
+ " (1): Dropout(p=0.0, inplace=False)\n",
252
+ " )\n",
253
+ " )\n",
254
+ " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
255
+ " (attn2): Attention(\n",
256
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
257
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
258
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
259
+ " (to_out): ModuleList(\n",
260
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
261
+ " (1): Dropout(p=0.0, inplace=False)\n",
262
+ " )\n",
263
+ " )\n",
264
+ " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
265
+ " (ff): FeedForward(\n",
266
+ " (net): ModuleList(\n",
267
+ " (0): GEGLU(\n",
268
+ " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
269
+ " )\n",
270
+ " (1): Dropout(p=0.0, inplace=False)\n",
271
+ " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
272
+ " )\n",
273
+ " )\n",
274
+ " )\n",
275
+ " )\n",
276
+ " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
277
+ " )\n",
278
+ " )\n",
279
+ " (resnets): ModuleList(\n",
280
+ " (0-2): 3 x ResnetBlock2D(\n",
281
+ " (norm1): GroupNorm(48, 2304, eps=1e-05, affine=True)\n",
282
+ " (conv1): Conv2d(2304, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
283
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
284
+ " (norm2): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
285
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
286
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
287
+ " (nonlinearity): SiLU()\n",
288
+ " (conv_shortcut): Conv2d(2304, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
289
+ " )\n",
290
+ " )\n",
291
+ " (upsamplers): ModuleList(\n",
292
+ " (0): Upsample2D(\n",
293
+ " (conv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
294
+ " )\n",
295
+ " )\n",
296
+ " )\n",
297
+ " (1): CrossAttnUpBlock2D(\n",
298
+ " (attentions): ModuleList(\n",
299
+ " (0-2): 3 x Transformer2DModel(\n",
300
+ " (norm): GroupNorm(48, 1152, eps=1e-06, affine=True)\n",
301
+ " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
302
+ " (transformer_blocks): ModuleList(\n",
303
+ " (0): BasicTransformerBlock(\n",
304
+ " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
305
+ " (attn1): Attention(\n",
306
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
307
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
308
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
309
+ " (to_out): ModuleList(\n",
310
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
311
+ " (1): Dropout(p=0.0, inplace=False)\n",
312
+ " )\n",
313
+ " )\n",
314
+ " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
315
+ " (attn2): Attention(\n",
316
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
317
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
318
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
319
+ " (to_out): ModuleList(\n",
320
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
321
+ " (1): Dropout(p=0.0, inplace=False)\n",
322
+ " )\n",
323
+ " )\n",
324
+ " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
325
+ " (ff): FeedForward(\n",
326
+ " (net): ModuleList(\n",
327
+ " (0): GEGLU(\n",
328
+ " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
329
+ " )\n",
330
+ " (1): Dropout(p=0.0, inplace=False)\n",
331
+ " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
332
+ " )\n",
333
+ " )\n",
334
+ " )\n",
335
+ " )\n",
336
+ " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
337
+ " )\n",
338
+ " )\n",
339
+ " (resnets): ModuleList(\n",
340
+ " (0-1): 2 x ResnetBlock2D(\n",
341
+ " (norm1): GroupNorm(48, 2304, eps=1e-05, affine=True)\n",
342
+ " (conv1): Conv2d(2304, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
343
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
344
+ " (norm2): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
345
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
346
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
347
+ " (nonlinearity): SiLU()\n",
348
+ " (conv_shortcut): Conv2d(2304, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
349
+ " )\n",
350
+ " (2): ResnetBlock2D(\n",
351
+ " (norm1): GroupNorm(48, 1728, eps=1e-05, affine=True)\n",
352
+ " (conv1): Conv2d(1728, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
353
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
354
+ " (norm2): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
355
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
356
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
357
+ " (nonlinearity): SiLU()\n",
358
+ " (conv_shortcut): Conv2d(1728, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
359
+ " )\n",
360
+ " )\n",
361
+ " (upsamplers): ModuleList(\n",
362
+ " (0): Upsample2D(\n",
363
+ " (conv): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
364
+ " )\n",
365
+ " )\n",
366
+ " )\n",
367
+ " (2): CrossAttnUpBlock2D(\n",
368
+ " (attentions): ModuleList(\n",
369
+ " (0-2): 3 x Transformer2DModel(\n",
370
+ " (norm): GroupNorm(48, 576, eps=1e-06, affine=True)\n",
371
+ " (proj_in): Linear(in_features=576, out_features=576, bias=True)\n",
372
+ " (transformer_blocks): ModuleList(\n",
373
+ " (0): BasicTransformerBlock(\n",
374
+ " (norm1): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
375
+ " (attn1): Attention(\n",
376
+ " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
377
+ " (to_k): Linear(in_features=576, out_features=576, bias=False)\n",
378
+ " (to_v): Linear(in_features=576, out_features=576, bias=False)\n",
379
+ " (to_out): ModuleList(\n",
380
+ " (0): Linear(in_features=576, out_features=576, bias=True)\n",
381
+ " (1): Dropout(p=0.0, inplace=False)\n",
382
+ " )\n",
383
+ " )\n",
384
+ " (norm2): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
385
+ " (attn2): Attention(\n",
386
+ " (to_q): Linear(in_features=576, out_features=576, bias=False)\n",
387
+ " (to_k): Linear(in_features=1152, out_features=576, bias=False)\n",
388
+ " (to_v): Linear(in_features=1152, out_features=576, bias=False)\n",
389
+ " (to_out): ModuleList(\n",
390
+ " (0): Linear(in_features=576, out_features=576, bias=True)\n",
391
+ " (1): Dropout(p=0.0, inplace=False)\n",
392
+ " )\n",
393
+ " )\n",
394
+ " (norm3): LayerNorm((576,), eps=1e-05, elementwise_affine=True)\n",
395
+ " (ff): FeedForward(\n",
396
+ " (net): ModuleList(\n",
397
+ " (0): GEGLU(\n",
398
+ " (proj): Linear(in_features=576, out_features=4608, bias=True)\n",
399
+ " )\n",
400
+ " (1): Dropout(p=0.0, inplace=False)\n",
401
+ " (2): Linear(in_features=2304, out_features=576, bias=True)\n",
402
+ " )\n",
403
+ " )\n",
404
+ " )\n",
405
+ " )\n",
406
+ " (proj_out): Linear(in_features=576, out_features=576, bias=True)\n",
407
+ " )\n",
408
+ " )\n",
409
+ " (resnets): ModuleList(\n",
410
+ " (0): ResnetBlock2D(\n",
411
+ " (norm1): GroupNorm(48, 1728, eps=1e-05, affine=True)\n",
412
+ " (conv1): Conv2d(1728, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
413
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
414
+ " (norm2): GroupNorm(48, 576, eps=1e-05, affine=True)\n",
415
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
416
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
417
+ " (nonlinearity): SiLU()\n",
418
+ " (conv_shortcut): Conv2d(1728, 576, kernel_size=(1, 1), stride=(1, 1))\n",
419
+ " )\n",
420
+ " (1): ResnetBlock2D(\n",
421
+ " (norm1): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
422
+ " (conv1): Conv2d(1152, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
423
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
424
+ " (norm2): GroupNorm(48, 576, eps=1e-05, affine=True)\n",
425
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
426
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
427
+ " (nonlinearity): SiLU()\n",
428
+ " (conv_shortcut): Conv2d(1152, 576, kernel_size=(1, 1), stride=(1, 1))\n",
429
+ " )\n",
430
+ " (2): ResnetBlock2D(\n",
431
+ " (norm1): GroupNorm(48, 864, eps=1e-05, affine=True)\n",
432
+ " (conv1): Conv2d(864, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
433
+ " (time_emb_proj): Linear(in_features=1152, out_features=576, bias=True)\n",
434
+ " (norm2): GroupNorm(48, 576, eps=1e-05, affine=True)\n",
435
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
436
+ " (conv2): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
437
+ " (nonlinearity): SiLU()\n",
438
+ " (conv_shortcut): Conv2d(864, 576, kernel_size=(1, 1), stride=(1, 1))\n",
439
+ " )\n",
440
+ " )\n",
441
+ " (upsamplers): ModuleList(\n",
442
+ " (0): Upsample2D(\n",
443
+ " (conv): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
444
+ " )\n",
445
+ " )\n",
446
+ " )\n",
447
+ " (3): UpBlock2D(\n",
448
+ " (resnets): ModuleList(\n",
449
+ " (0): ResnetBlock2D(\n",
450
+ " (norm1): GroupNorm(48, 864, eps=1e-05, affine=True)\n",
451
+ " (conv1): Conv2d(864, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
452
+ " (time_emb_proj): Linear(in_features=1152, out_features=288, bias=True)\n",
453
+ " (norm2): GroupNorm(48, 288, eps=1e-05, affine=True)\n",
454
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
455
+ " (conv2): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
456
+ " (nonlinearity): SiLU()\n",
457
+ " (conv_shortcut): Conv2d(864, 288, kernel_size=(1, 1), stride=(1, 1))\n",
458
+ " )\n",
459
+ " (1-2): 2 x ResnetBlock2D(\n",
460
+ " (norm1): GroupNorm(48, 576, eps=1e-05, affine=True)\n",
461
+ " (conv1): Conv2d(576, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
462
+ " (time_emb_proj): Linear(in_features=1152, out_features=288, bias=True)\n",
463
+ " (norm2): GroupNorm(48, 288, eps=1e-05, affine=True)\n",
464
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
465
+ " (conv2): Conv2d(288, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
466
+ " (nonlinearity): SiLU()\n",
467
+ " (conv_shortcut): Conv2d(576, 288, kernel_size=(1, 1), stride=(1, 1))\n",
468
+ " )\n",
469
+ " )\n",
470
+ " )\n",
471
+ " )\n",
472
+ " (mid_block): UNetMidBlock2DCrossAttn(\n",
473
+ " (attentions): ModuleList(\n",
474
+ " (0): Transformer2DModel(\n",
475
+ " (norm): GroupNorm(48, 1152, eps=1e-06, affine=True)\n",
476
+ " (proj_in): Linear(in_features=1152, out_features=1152, bias=True)\n",
477
+ " (transformer_blocks): ModuleList(\n",
478
+ " (0-7): 8 x BasicTransformerBlock(\n",
479
+ " (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
480
+ " (attn1): Attention(\n",
481
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
482
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
483
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
484
+ " (to_out): ModuleList(\n",
485
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
486
+ " (1): Dropout(p=0.0, inplace=False)\n",
487
+ " )\n",
488
+ " )\n",
489
+ " (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
490
+ " (attn2): Attention(\n",
491
+ " (to_q): Linear(in_features=1152, out_features=1152, bias=False)\n",
492
+ " (to_k): Linear(in_features=1152, out_features=1152, bias=False)\n",
493
+ " (to_v): Linear(in_features=1152, out_features=1152, bias=False)\n",
494
+ " (to_out): ModuleList(\n",
495
+ " (0): Linear(in_features=1152, out_features=1152, bias=True)\n",
496
+ " (1): Dropout(p=0.0, inplace=False)\n",
497
+ " )\n",
498
+ " )\n",
499
+ " (norm3): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)\n",
500
+ " (ff): FeedForward(\n",
501
+ " (net): ModuleList(\n",
502
+ " (0): GEGLU(\n",
503
+ " (proj): Linear(in_features=1152, out_features=9216, bias=True)\n",
504
+ " )\n",
505
+ " (1): Dropout(p=0.0, inplace=False)\n",
506
+ " (2): Linear(in_features=4608, out_features=1152, bias=True)\n",
507
+ " )\n",
508
+ " )\n",
509
+ " )\n",
510
+ " )\n",
511
+ " (proj_out): Linear(in_features=1152, out_features=1152, bias=True)\n",
512
+ " )\n",
513
+ " )\n",
514
+ " (resnets): ModuleList(\n",
515
+ " (0-1): 2 x ResnetBlock2D(\n",
516
+ " (norm1): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
517
+ " (conv1): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
518
+ " (time_emb_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
519
+ " (norm2): GroupNorm(48, 1152, eps=1e-05, affine=True)\n",
520
+ " (dropout): Dropout(p=0.0, inplace=False)\n",
521
+ " (conv2): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
522
+ " (nonlinearity): SiLU()\n",
523
+ " )\n",
524
+ " )\n",
525
+ " )\n",
526
+ " (conv_norm_out): GroupNorm(48, 288, eps=1e-05, affine=True)\n",
527
+ " (conv_act): SiLU()\n",
528
+ " (conv_out): Conv2d(288, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
529
+ ")\n"
530
+ ]
531
+ }
532
+ ],
533
+ "source": [
534
+ "config_sdxs = {\n",
535
+ " # === Основные размеры и каналы ===\n",
536
+ " \"in_channels\": 48, # Количество входных каналов (совместимость с VAE)\n",
537
+ " \"out_channels\": 48, # Количество выходных каналов (симметрично in_channels) \n",
538
+ "\n",
539
+ " # === Cross-Attention ===\n",
540
+ " \"cross_attention_dim\": 1152, # Размерность текстовых эмбеддингов\n",
541
+ " \"use_linear_projection\": True,\n",
542
+ " \"norm_num_groups\": 48,\n",
543
+ " \n",
544
+ " # === Архитектура блоков ===\n",
545
+ " \"down_block_types\": [ # энкодер\n",
546
+ " \"DownBlock2D\",\n",
547
+ " \"CrossAttnDownBlock2D\",\n",
548
+ " \"CrossAttnDownBlock2D\",\n",
549
+ " \"CrossAttnDownBlock2D\",\n",
550
+ " ],\n",
551
+ " \"up_block_types\": [ # декодер\n",
552
+ " \"CrossAttnUpBlock2D\",\n",
553
+ " \"CrossAttnUpBlock2D\",\n",
554
+ " \"CrossAttnUpBlock2D\",\n",
555
+ " \"UpBlock2D\",\n",
556
+ " ],\n",
557
+ "\n",
558
+ " # === Конфигурация каналов ===\n",
559
+ " \"block_out_channels\": [288, 576, 1152, 1152],\n",
560
+ "\n",
561
+ " \"transformer_layers_per_block\": [1, 1, 1, 8],\n",
562
+ " \"attention_head_dim\": [6, 12, 24, 24],\n",
563
+ "}\n",
564
+ "\n",
565
+ "def check_initialization(model):\n",
566
+ " for name, param in model.named_parameters():\n",
567
+ " if param.requires_grad:\n",
568
+ " print(f\"{name}: mean={param.data.mean():.3f}, std={param.data.std():.3f}\")\n",
569
+ "\n",
570
+ "\n",
571
+ "if 1:\n",
572
+ " checkpoint_path = \"/workspace/sdxs3d/unet\"#\"sdxs\"\n",
573
+ " import torch\n",
574
+ " from diffusers import UNet2DConditionModel\n",
575
+ " print(\"test unet\")\n",
576
+ " new_unet = UNet2DConditionModel(**config_sdxs).to(\"cuda\", dtype=torch.float16)\n",
577
+ " #new_unet = UNet2DConditionModel().to(\"cuda\", dtype=torch.float16)\n",
578
+ "\n",
579
+ " # После инициализации\n",
580
+ " #check_initialization(new_unet)\n",
581
+ "\n",
582
+ " #assert all(ch % 32 == 0 for ch in new_unet.config[\"block_out_channels\"]), \"Каналы должны быть кратны 32\"\n",
583
+ " num_params = sum(p.numel() for p in new_unet.parameters())\n",
584
+ " print(f\"Количество параметров: {num_params}\")\n",
585
+ "\n",
586
+ " # Генерация тестового латента (640x512 в latent space)\n",
587
+ " test_latent = torch.randn(1, 48, 60, 48).to(\"cuda\", dtype=torch.float16) # 60x48 ≈ 512px\n",
588
+ " timesteps = torch.tensor([1]).to(\"cuda\", dtype=torch.float16)\n",
589
+ " encoder_hidden_states = torch.randn(1, 77, 1152).to(\"cuda\", dtype=torch.float16)\n",
590
+ " \n",
591
+ " with torch.no_grad():\n",
592
+ " output = new_unet(\n",
593
+ " test_latent, \n",
594
+ " timesteps, \n",
595
+ " encoder_hidden_states\n",
596
+ " ).sample\n",
597
+ "\n",
598
+ " print(f\"Output shape: {output.shape}\")\n",
599
+ " new_unet.save_pretrained(checkpoint_path)\n",
600
+ " print(new_unet) "
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": null,
606
+ "id": "1cb4ff0f-36cc-43cf-86a4-aaab9f106725",
607
+ "metadata": {},
608
+ "outputs": [],
609
+ "source": []
610
+ }
611
+ ],
612
+ "metadata": {
613
+ "kernelspec": {
614
+ "display_name": "Python 3 (ipykernel)",
615
+ "language": "python",
616
+ "name": "python3"
617
+ },
618
+ "language_info": {
619
+ "codemirror_mode": {
620
+ "name": "ipython",
621
+ "version": 3
622
+ },
623
+ "file_extension": ".py",
624
+ "mimetype": "text/x-python",
625
+ "name": "python",
626
+ "nbconvert_exporter": "python",
627
+ "pygments_lexer": "ipython3",
628
+ "version": "3.11.10"
629
+ }
630
+ },
631
+ "nbformat": 4,
632
+ "nbformat_minor": 5
633
+ }
train.py ADDED
@@ -0,0 +1,825 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader, Sampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+ from collections import defaultdict
10
+ from torch.optim.lr_scheduler import LambdaLR
11
+ from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL, DDPMScheduler
12
+ from accelerate import Accelerator
13
+ from datasets import load_from_disk
14
+ from tqdm import tqdm
15
+ from PIL import Image,ImageOps
16
+ import wandb
17
+ import random
18
+ import gc
19
+ from accelerate.state import DistributedType
20
+ from torch.distributed import broadcast_object_list
21
+ from torch.utils.checkpoint import checkpoint
22
+ from diffusers.models.attention_processor import AttnProcessor2_0
23
+ from datetime import datetime
24
+ import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
+ from collections import deque
27
+
28
+ # --------------------------- Параметры ---------------------------
29
+ ds_path = "/workspace/sdxs3d/datasets/butterfly"
30
+ project = "unet"
31
+ batch_size = 16
32
+ base_learning_rate = 9e-5
33
+ min_learning_rate = 1e-5
34
+ num_epochs = 30
35
+ # samples/save per epoch
36
+ sample_interval_share = 1
37
+ use_wandb = False
38
+ save_model = True
39
+ use_decay = True
40
+ fbp = False # fused backward pass
41
+ optimizer_type = "adam8bit"
42
+ torch_compile = False
43
+ unet_gradient = True
44
+ clip_sample = False #Scheduler
45
+ fixed_seed = False
46
+ shuffle = True
47
+ torch.backends.cuda.matmul.allow_tf32 = True
48
+ torch.backends.cudnn.allow_tf32 = True
49
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
50
+ dtype = torch.float32
51
+ save_barrier = 1.03
52
+ warmup_percent = 0.01
53
+ dispersive_temperature=0.5
54
+ dispersive_weight= 0.05
55
+ percentile_clipping = 99 # 8bit optim
56
+ betta2 = 0.995
57
+ eps = 1e-8
58
+ clip_grad_norm = 1.0
59
+ steps_offset = 0 # Scheduler
60
+ limit = 0
61
+ checkpoints_folder = ""
62
+ mixed_precision = "no" #"fp16"
63
+ gradient_accumulation_steps = 1
64
+ accelerator = Accelerator(
65
+ mixed_precision=mixed_precision,
66
+ gradient_accumulation_steps=gradient_accumulation_steps
67
+ )
68
+ device = accelerator.device
69
+
70
+ # Параметры для диффузии
71
+ n_diffusion_steps = 50
72
+ samples_to_generate = 12
73
+ guidance_scale = 5
74
+
75
+ # Папки для сохранения результатов
76
+ generated_folder = "samples"
77
+ os.makedirs(generated_folder, exist_ok=True)
78
+
79
+ # Настройка seed для воспроизводимости
80
+ current_date = datetime.now()
81
+ seed = int(current_date.strftime("%Y%m%d"))
82
+ if fixed_seed:
83
+ torch.manual_seed(seed)
84
+ np.random.seed(seed)
85
+ random.seed(seed)
86
+ if torch.cuda.is_available():
87
+ torch.cuda.manual_seed_all(seed)
88
+
89
+ # --- Пропорции лоссов и окно медианного нормирования (КОЭФ., не значения) ---
90
+ # CHANGED: добавлен huber и dispersive в пропорции, суммы = 1.0
91
+ loss_ratios = {
92
+ "mse": 0.60,
93
+ "mae": 0.35,
94
+ "huber": 0.0,
95
+ "dispersive": 0.05,
96
+ }
97
+ median_coeff_steps = 128 # за сколько шагов считать медианные коэффициенты
98
+
99
+ # --------------------------- Параметры LoRA ---------------------------
100
+ lora_name = ""
101
+ lora_rank = 32
102
+ lora_alpha = 64
103
+
104
+ print("init")
105
+
106
+ # --------------------------- вспомогательные функции ---------------------------
107
+ def sample_timesteps_bias(
108
+ batch_size: int,
109
+ progress: float, # [0..1]
110
+ num_train_timesteps: int, # обычно 1000
111
+ steps_offset: int = 0,
112
+ device=None,
113
+ mode: str = "beta", # "beta", "uniform"
114
+ ) -> torch.Tensor:
115
+ """
116
+ Возвращает timesteps с разным bias:
117
+ - beta : как раньше (сдвиг в начало или конец в зависимости от progress)
118
+ - normal : около середины (гауссовое распределение)
119
+ - uniform: равномерно по всем timestep’ам
120
+ """
121
+
122
+ max_idx = num_train_timesteps - 1 - steps_offset
123
+
124
+ if mode == "beta":
125
+ alpha = 1.0 + .5 * (1.0 - progress)
126
+ beta = 1.0 + .5 * progress
127
+ samples = torch.distributions.Beta(alpha, beta).sample((batch_size,))
128
+
129
+ elif mode == "uniform":
130
+ samples = torch.rand(batch_size)
131
+
132
+ else:
133
+ raise ValueError(f"Unknown mode: {mode}")
134
+
135
+ timesteps = steps_offset + (samples * max_idx).long().to(device)
136
+ return timesteps
137
+
138
+
139
+ # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ
140
+ class MedianLossNormalizer:
141
+ def __init__(self, desired_ratios: dict, window_steps: int):
142
+ # нормируем доли на случай, если сумма != 1
143
+ s = sum(desired_ratios.values())
144
+ self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
145
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
146
+ self.window = window_steps
147
+
148
+ def update_and_total(self, losses: dict):
149
+ """
150
+ losses: dict ключ->тензор (значения лоссов)
151
+ Поведение:
152
+ - буферим ABS(l) только для активных (ratio>0) лоссов
153
+ - coeff = ratio / median(abs(loss))
154
+ - total = sum(coeff * loss) по активным лоссам
155
+ CHANGED: буферим abs() — чтобы медиана была положительной и не ломала деление.
156
+ """
157
+ # буферим только активные лоссы
158
+ for k, v in losses.items():
159
+ if k in self.buffers and self.ratios.get(k, 0) > 0:
160
+ self.buffers[k].append(float(v.detach().abs().cpu()))
161
+
162
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
163
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
164
+
165
+ # суммируем только по активным (ratio>0)
166
+ total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
167
+ return total, coeffs, meds
168
+
169
+ # создаём normalizer после определения loss_ratios
170
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
171
+
172
+ class AccelerateDispersiveLoss:
173
+ def __init__(self, accelerator, temperature=0.5, weight=0.5):
174
+ self.accelerator = accelerator
175
+ self.temperature = temperature
176
+ self.weight = weight
177
+ self.activations = []
178
+ self.hooks = []
179
+
180
+ def register_hooks(self, model, target_layer="down_blocks.0"):
181
+ unwrapped_model = self.accelerator.unwrap_model(model)
182
+ print("=== Поиск слоев в unwrapped модели ===")
183
+ for name, module in unwrapped_model.named_modules():
184
+ if target_layer in name:
185
+ hook = module.register_forward_hook(self.hook_fn)
186
+ self.hooks.append(hook)
187
+ print(f"✅ Хук зарегистрирован на: {name}")
188
+ break
189
+
190
+ def hook_fn(self, module, input, output):
191
+ if isinstance(output, tuple):
192
+ activation = output[0]
193
+ else:
194
+ activation = output
195
+ if len(activation.shape) > 2:
196
+ activation = activation.view(activation.shape[0], -1)
197
+ self.activations.append(activation.detach().clone())
198
+
199
+ def compute_dispersive_loss(self):
200
+ if not self.activations:
201
+ return torch.tensor(0.0, requires_grad=True, device=device)
202
+ local_activations = self.activations[-1].float()
203
+ batch_size = local_activations.shape[0]
204
+ if batch_size < 2:
205
+ return torch.tensor(0.0, requires_grad=True, device=device)
206
+ sf = local_activations / torch.norm(local_activations, dim=1, keepdim=True)
207
+ distance = torch.nn.functional.pdist(sf.float(), p=2) ** 2
208
+ exp_neg_dist = torch.exp(-distance / self.temperature) + 1e-5
209
+ dispersive_loss = torch.log(torch.mean(exp_neg_dist))
210
+ return dispersive_loss
211
+
212
+ def clear_activations(self):
213
+ self.activations.clear()
214
+
215
+ def remove_hooks(self):
216
+ for hook in self.hooks:
217
+ hook.remove()
218
+ self.hooks.clear()
219
+
220
+
221
+ # --------------------------- Инициализация WandB ---------------------------
222
+ if use_wandb and accelerator.is_main_process:
223
+ wandb.init(project=project+lora_name, config={
224
+ "batch_size": batch_size,
225
+ "base_learning_rate": base_learning_rate,
226
+ "num_epochs": num_epochs,
227
+ "fbp": fbp,
228
+ "optimizer_type": optimizer_type,
229
+ })
230
+
231
+ # Включение Flash Attention 2/SDPA
232
+ torch.backends.cuda.enable_flash_sdp(True)
233
+ # --------------------------- Инициализация Accelerator --------------------
234
+ gen = torch.Generator(device=device)
235
+ gen.manual_seed(seed)
236
+
237
+ # --------------------------- Загрузка моделей ---------------------------
238
+ # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
239
+ #vae = AutoencoderKLWan.from_pretrained("vae", variant="fp16").to(device="cpu", dtype=torch.float16).eval()
240
+ #vae = AutoencoderKLWan.from_pretrained(
241
+ # "AiArtLab/simplevae", subfolder="wan16x_vae_nightly",
242
+ # torch_dtype=dtype
243
+ # ).to(device="cpu").eval()
244
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", subfolder=None,torch_dtype=dtype).to(device).eval()
245
+
246
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
247
+ if shift_factor is None:
248
+ shift_factor = 0.0
249
+
250
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
251
+ if scaling_factor is None:
252
+ scaling_factor = 1.0
253
+
254
+ latents_mean = getattr(vae.config, "latents_mean", None)
255
+ latents_std = getattr(vae.config, "latents_std", None)
256
+
257
+ # DDPMScheduler с V_Prediction и Zero-SNR
258
+ scheduler = DDPMScheduler(
259
+ num_train_timesteps=1000,
260
+ prediction_type="v_prediction",
261
+ rescale_betas_zero_snr=True,
262
+ clip_sample = clip_sample,
263
+ steps_offset = steps_offset
264
+ )
265
+
266
+
267
+ class DistributedResolutionBatchSampler(Sampler):
268
+ def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
269
+ self.dataset = dataset
270
+ self.batch_size = max(1, batch_size // num_replicas)
271
+ self.num_replicas = num_replicas
272
+ self.rank = rank
273
+ self.shuffle = shuffle
274
+ self.drop_last = drop_last
275
+ self.epoch = 0
276
+
277
+ try:
278
+ widths = np.array(dataset["width"])
279
+ heights = np.array(dataset["height"])
280
+ except KeyError:
281
+ widths = np.zeros(len(dataset))
282
+ heights = np.zeros(len(dataset))
283
+
284
+ self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
285
+ self.size_groups = {}
286
+ for w, h in self.size_keys:
287
+ mask = (widths == w) & (heights == h)
288
+ self.size_groups[(w, h)] = np.where(mask)[0]
289
+
290
+ self.group_num_batches = {}
291
+ total_batches = 0
292
+ for size, indices in self.size_groups.items():
293
+ num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
294
+ self.group_num_batches[size] = num_full_batches
295
+ total_batches += num_full_batches
296
+
297
+ self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
298
+
299
+ def __iter__(self):
300
+ if torch.cuda.is_available():
301
+ torch.cuda.empty_cache()
302
+ all_batches = []
303
+ rng = np.random.RandomState(self.epoch)
304
+
305
+ for size, indices in self.size_groups.items():
306
+ indices = indices.copy()
307
+ if self.shuffle:
308
+ rng.shuffle(indices)
309
+ num_full_batches = self.group_num_batches[size]
310
+ if num_full_batches == 0:
311
+ continue
312
+ valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
313
+ batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
314
+ start_idx = self.rank * self.batch_size
315
+ end_idx = start_idx + self.batch_size
316
+ gpu_batches = batches[:, start_idx:end_idx]
317
+ all_batches.extend(gpu_batches)
318
+
319
+ if self.shuffle:
320
+ rng.shuffle(all_batches)
321
+ accelerator.wait_for_everyone()
322
+ return iter(all_batches)
323
+
324
+ def __len__(self):
325
+ return self.num_batches
326
+
327
+ def set_epoch(self, epoch):
328
+ self.epoch = epoch
329
+
330
+ # Функция для выборки фиксированных семплов по размерам
331
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
332
+ size_groups = defaultdict(list)
333
+ try:
334
+ widths = dataset["width"]
335
+ heights = dataset["height"]
336
+ except KeyError:
337
+ widths = [0] * len(dataset)
338
+ heights = [0] * len(dataset)
339
+ for i, (w, h) in enumerate(zip(widths, heights)):
340
+ size = (w, h)
341
+ size_groups[size].append(i)
342
+
343
+ fixed_samples = {}
344
+ for size, indices in size_groups.items():
345
+ n_samples = min(samples_per_group, len(indices))
346
+ if len(size_groups)==1:
347
+ n_samples = samples_to_generate
348
+ if n_samples == 0:
349
+ continue
350
+ sample_indices = random.sample(indices, n_samples)
351
+ samples_data = [dataset[idx] for idx in sample_indices]
352
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
353
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
354
+ texts = [item["text"] for item in samples_data]
355
+ fixed_samples[size] = (latents, embeddings, texts)
356
+
357
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
358
+ return fixed_samples
359
+
360
+ if limit > 0:
361
+ dataset = load_from_disk(ds_path).select(range(limit))
362
+ else:
363
+ dataset = load_from_disk(ds_path)
364
+
365
+ def collate_fn_simple(batch):
366
+ latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
367
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
368
+ return latents, embeddings
369
+
370
+ batch_sampler = DistributedResolutionBatchSampler(
371
+ dataset=dataset,
372
+ batch_size=batch_size,
373
+ num_replicas=accelerator.num_processes,
374
+ rank=accelerator.process_index,
375
+ shuffle=shuffle
376
+ )
377
+
378
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
379
+ print("Total samples",len(dataloader))
380
+ dataloader = accelerator.prepare(dataloader)
381
+
382
+ start_epoch = 0
383
+ global_step = 0
384
+ total_training_steps = (len(dataloader) * num_epochs)
385
+ world_size = accelerator.state.num_processes
386
+
387
+ # Опция загрузки модели из последнего чекпоинта (если существует)
388
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
389
+ if os.path.isdir(latest_checkpoint):
390
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
391
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
392
+ if torch_compile:
393
+ print("compiling")
394
+ torch.set_float32_matmul_precision('high')
395
+ unet = torch.compile(unet)
396
+ print("compiling - ok")
397
+ if unet_gradient:
398
+ unet.enable_gradient_checkpointing()
399
+ unet.set_use_memory_efficient_attention_xformers(False)
400
+ try:
401
+ unet.set_attn_processor(AttnProcessor2_0())
402
+ except Exception as e:
403
+ print(f"Ошибка при включении SDPA: {e}")
404
+ unet.set_use_memory_efficient_attention_xformers(True)
405
+
406
+ # Создаём hook для dispersive только если нужно
407
+ if loss_ratios.get("dispersive", 0) > 0:
408
+ dispersive_hook = AccelerateDispersiveLoss(
409
+ accelerator=accelerator,
410
+ temperature=dispersive_temperature,
411
+ weight=dispersive_weight
412
+ )
413
+ else:
414
+ # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
415
+ raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
416
+
417
+ if lora_name:
418
+ print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
419
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
420
+ from peft.tuners.lora import LoraModel
421
+ import os
422
+ unet.requires_grad_(False)
423
+ print("Параметры базового UNet заморожены.")
424
+
425
+ lora_config = LoraConfig(
426
+ r=lora_rank,
427
+ lora_alpha=lora_alpha,
428
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
429
+ )
430
+ unet.add_adapter(lora_config)
431
+
432
+ from peft import get_peft_model
433
+ peft_unet = get_peft_model(unet, lora_config)
434
+ params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
435
+
436
+ if accelerator.is_main_process:
437
+ lora_params_count = sum(p.numel() for p in params_to_optimize)
438
+ total_params_count = sum(p.numel() for p in unet.parameters())
439
+ print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
440
+ print(f"Общее количество параметров UNet: {total_params_count:,}")
441
+
442
+ lora_save_path = os.path.join("lora", lora_name)
443
+ os.makedirs(lora_save_path, exist_ok=True)
444
+
445
+ def save_lora_checkpoint(model):
446
+ if accelerator.is_main_process:
447
+ print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
448
+ from peft.utils.save_and_load import get_peft_model_state_dict
449
+ lora_state_dict = get_peft_model_state_dict(model)
450
+ torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
451
+ model.peft_config["default"].save_pretrained(lora_save_path)
452
+ from diffusers import StableDiffusionXLPipeline
453
+ StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
454
+
455
+ # --------------------------- Оптимизатор ---------------------------
456
+ if lora_name:
457
+ trainable_params = [p for p in unet.parameters() if p.requires_grad]
458
+ else:
459
+ if fbp:
460
+ trainable_params = list(unet.parameters())
461
+
462
+ def create_optimizer(name, params):
463
+ if name == "adam8bit":
464
+ return bnb.optim.AdamW8bit(
465
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
466
+ percentile_clipping=percentile_clipping
467
+ )
468
+ elif name == "adam":
469
+ return torch.optim.AdamW(
470
+ params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
471
+ )
472
+ elif name == "lion8bit":
473
+ return bnb.optim.Lion8bit(
474
+ params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
475
+ percentile_clipping=percentile_clipping
476
+ )
477
+ elif name == "adafactor":
478
+ from transformers import Adafactor
479
+ return Adafactor(
480
+ params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
481
+ warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
482
+ beta1=0.9, weight_decay=0.01
483
+ )
484
+ else:
485
+ raise ValueError(f"Unknown optimizer: {name}")
486
+
487
+ if fbp:
488
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
489
+ def optimizer_hook(param):
490
+ optimizer_dict[param].step()
491
+ optimizer_dict[param].zero_grad(set_to_none=True)
492
+ for param in trainable_params:
493
+ param.register_post_accumulate_grad_hook(optimizer_hook)
494
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
495
+ else:
496
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
497
+ def lr_schedule(step):
498
+ x = step / (total_training_steps * world_size)
499
+ warmup = warmup_percent
500
+ if not use_decay:
501
+ return base_learning_rate
502
+ if x < warmup:
503
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
504
+ decay_ratio = (x - warmup) / (1 - warmup)
505
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
506
+ (1 + math.cos(math.pi * decay_ratio))
507
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
508
+
509
+ num_params = sum(p.numel() for p in unet.parameters())
510
+ print(f"[rank {accelerator.process_index}] total params: {num_params}")
511
+ for name, param in unet.named_parameters():
512
+ if torch.isnan(param).any() or torch.isinf(param).any():
513
+ print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
514
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
515
+
516
+ # Регистрация хуков ПОСЛЕ prepare
517
+ if loss_ratios.get("dispersive", 0) > 0:
518
+ dispersive_hook.register_hooks(unet, "down_blocks.2")
519
+
520
+ # --------------------------- Фиксированные семплы для генерации ---------------------------
521
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
522
+
523
+ @torch.compiler.disable()
524
+ @torch.no_grad()
525
+ def generate_and_save_samples(fixed_samples_cpu, step):
526
+ original_model = None
527
+ try:
528
+ original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
529
+ vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
530
+
531
+ scheduler.set_timesteps(n_diffusion_steps)
532
+
533
+ all_generated_images = []
534
+ all_captions = []
535
+
536
+ for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
537
+ width, height = size
538
+ sample_latents = sample_latents.to(dtype=dtype, device=device)
539
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
540
+
541
+ noise = torch.randn(
542
+ sample_latents.shape,
543
+ generator=gen,
544
+ device=device,
545
+ dtype=sample_latents.dtype
546
+ )
547
+ current_latents = noise.clone()
548
+
549
+ if guidance_scale > 0:
550
+ empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
551
+ text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
552
+ else:
553
+ text_embeddings_batch = sample_text_embeddings
554
+
555
+ for t in scheduler.timesteps:
556
+ t_batch = t.repeat(current_latents.shape[0]).to(device)
557
+ if guidance_scale > 0:
558
+ latent_model_input = torch.cat([current_latents] * 2)
559
+ else:
560
+ latent_model_input = current_latents
561
+
562
+ latent_model_input_scaled = scheduler.scale_model_input(latent_model_input, t_batch)
563
+ noise_pred = original_model(latent_model_input_scaled, t_batch, text_embeddings_batch).sample
564
+
565
+ if guidance_scale > 0:
566
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
567
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
568
+
569
+ current_latents = scheduler.step(noise_pred, t, current_latents).prev_sample
570
+
571
+ #print(current_latents.ndim, current_latents.shape)
572
+ #if current_latents.ndim == 4:
573
+ # current_latents = current_latents.unsqueeze(2)
574
+ # Латент в форме [B, C, T, H, W]
575
+ #print(current_latents.ndim, current_latents.shape)
576
+
577
+ # Параметры нормализации
578
+ latent_for_vae = current_latents.detach() * scaling_factor + shift_factor
579
+
580
+ if latents_mean!=None and latents_std!=None:
581
+ latent_for_vae = latent_for_vae * torch.tensor(latents_std, device=device, dtype=dtype).view(1, -1, 1, 1, 1) + torch.tensor(latents_mean, device=device, dtype=dtype).view(1, -1, 1, 1, 1)
582
+
583
+ decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
584
+ #decoded = decoded[:, :, 0, :, :] # [3, H, W]
585
+ #print(decoded.ndim, decoded.shape)
586
+
587
+ decoded_fp32 = decoded.to(torch.float32)
588
+ for img_idx, img_tensor in enumerate(decoded_fp32):
589
+
590
+ # Форма: [3, H, W] -> преобразуем в [H, W, 3]
591
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
592
+ img = img.transpose(1, 2, 0) # Из [3, H, W] в [H, W, 3]
593
+
594
+ #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
595
+ if np.isnan(img).any():
596
+ print("NaNs found, saving stopped! Step:", step)
597
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
598
+
599
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
600
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
601
+ max_w_overall = max(255, max_w_overall)
602
+ max_h_overall = max(255, max_h_overall)
603
+
604
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
605
+ all_generated_images.append(padded_img)
606
+
607
+ caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
608
+ all_captions.append(caption_text)
609
+
610
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
611
+ pil_img.save(sample_path, "JPEG", quality=96)
612
+
613
+ if use_wandb and accelerator.is_main_process:
614
+ wandb_images = [
615
+ wandb.Image(img, caption=f"{all_captions[i]}")
616
+ for i, img in enumerate(all_generated_images)
617
+ ]
618
+ wandb.log({"generated_images": wandb_images, "global_step": step})
619
+ finally:
620
+ # вернуть VAE на CPU (как было в твоём коде)
621
+ vae.to("cpu")
622
+ for var in list(locals().keys()):
623
+ if isinstance(locals()[var], torch.Tensor):
624
+ del locals()[var]
625
+ torch.cuda.empty_cache()
626
+ gc.collect()
627
+
628
+ # --------------------------- Генерация сэмплов перед обучением ---------------------------
629
+ if accelerator.is_main_process:
630
+ if save_model:
631
+ print("Генерация сэмплов до старта обучения...")
632
+ generate_and_save_samples(fixed_samples,0)
633
+ accelerator.wait_for_everyone()
634
+
635
+ # Модифицируем функцию сохранения модели для поддержки LoRA
636
+ def save_checkpoint(unet,variant=""):
637
+ if accelerator.is_main_process:
638
+ if lora_name:
639
+ save_lora_checkpoint(unet)
640
+ else:
641
+ if variant!="":
642
+ accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
643
+ else:
644
+ accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
645
+ unet = unet.to(dtype=dtype)
646
+
647
+ def batch_pred_original_from_step(model_outputs, timesteps_tensor, noisy_latents, scheduler):
648
+ device = noisy_latents.device
649
+ dtype = noisy_latents.dtype
650
+
651
+ available_ts = scheduler.timesteps
652
+ if not isinstance(available_ts, torch.Tensor):
653
+ available_ts = torch.tensor(available_ts, device="cpu")
654
+ else:
655
+ available_ts = available_ts.cpu()
656
+
657
+ B = model_outputs.shape[0]
658
+ preds = []
659
+ for i in range(B):
660
+ t_i = int(timesteps_tensor[i].item())
661
+ diffs = torch.abs(available_ts - t_i)
662
+ idx = int(torch.argmin(diffs).item())
663
+ t_for_step = int(available_ts[idx].item())
664
+ model_out_i = model_outputs[i:i+1]
665
+ noisy_latent_i = noisy_latents[i:i+1]
666
+ step_out = scheduler.step(model_out_i, t_for_step, noisy_latent_i)
667
+ preds.append(step_out.pred_original_sample)
668
+
669
+ return torch.cat(preds, dim=0).to(device=device, dtype=dtype)
670
+
671
+ # --------------------------- Тренировочный цикл ---------------------------
672
+ if accelerator.is_main_process:
673
+ print(f"Total steps per GPU: {total_training_steps}")
674
+
675
+ epoch_loss_points = []
676
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
677
+
678
+ steps_per_epoch = len(dataloader)
679
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
680
+ min_loss = 1.
681
+
682
+ for epoch in range(start_epoch, start_epoch + num_epochs):
683
+ batch_losses = []
684
+ batch_tlosses = []
685
+ batch_grads = []
686
+ batch_sampler.set_epoch(epoch)
687
+ accelerator.wait_for_everyone()
688
+ unet.train()
689
+ print("epoch:",epoch)
690
+ for step, (latents, embeddings) in enumerate(dataloader):
691
+ with accelerator.accumulate(unet):
692
+ if save_model == False and step == 5 :
693
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
694
+ print(f"Шаг {step}: {used_gb:.2f} GB")
695
+
696
+ noise = torch.randn_like(latents, dtype=latents.dtype)
697
+
698
+ progress = global_step / max(1, total_training_steps - 1)
699
+ timesteps = sample_timesteps_bias(
700
+ batch_size=latents.shape[0],
701
+ progress=progress,
702
+ num_train_timesteps=scheduler.config.num_train_timesteps,
703
+ steps_offset=steps_offset,
704
+ device=device
705
+ )
706
+
707
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
708
+
709
+ if loss_ratios.get("dispersive", 0) > 0:
710
+ dispersive_hook.clear_activations()
711
+
712
+ #print(latents.shape,embeddings.shape)
713
+ model_pred = unet(noisy_latents, timesteps, embeddings).sample
714
+ target_pred = scheduler.get_velocity(latents, noise, timesteps)
715
+
716
+ # === Losses ===
717
+ losses_dict = {}
718
+
719
+ mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
720
+ losses_dict["mse"] = mse_loss
721
+ losses_dict["mae"] = F.l1_loss(model_pred.float(), target_pred.float())
722
+
723
+ # CHANGED: Huber (smooth_l1) loss added
724
+ losses_dict["huber"] = F.smooth_l1_loss(model_pred.float(), target_pred.float())
725
+
726
+ # === Dispersive loss ===
727
+ if loss_ratios.get("dispersive", 0) > 0:
728
+ disp_raw = dispersive_hook.compute_dispersive_loss().to(device) # может быть отрицательным
729
+ losses_dict["dispersive"] = dispersive_hook.weight * disp_raw
730
+ else:
731
+ losses_dict["dispersive"] = torch.tensor(0.0, device=device)
732
+
733
+ # === Нормализация всех лоссов ===
734
+ abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
735
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
736
+
737
+ # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
738
+ batch_losses.append(mse_loss.detach().item())
739
+
740
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
741
+ accelerator.wait_for_everyone()
742
+
743
+ # Backward
744
+ accelerator.backward(total_loss)
745
+
746
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
747
+ accelerator.wait_for_everyone()
748
+
749
+ grad = 0.0
750
+ if not fbp:
751
+ if accelerator.sync_gradients:
752
+ with torch.amp.autocast('cuda', enabled=False):
753
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
754
+ grad = float(grad_val)
755
+ optimizer.step()
756
+ lr_scheduler.step()
757
+ optimizer.zero_grad(set_to_none=True)
758
+
759
+ global_step += 1
760
+ progress_bar.update(1)
761
+
762
+ # Логируем метрики
763
+ if accelerator.is_main_process:
764
+ if fbp:
765
+ current_lr = base_learning_rate
766
+ else:
767
+ current_lr = lr_scheduler.get_last_lr()[0]
768
+ batch_tlosses.append(total_loss.detach().item())
769
+ batch_grads.append(grad)
770
+
771
+ # Логируем только активные лоссы (ratio>0)
772
+ active_keys = [k for k, v in loss_ratios.items() if v > 0]
773
+ log_data = {}
774
+ for k in active_keys:
775
+ v = losses_dict.get(k, None)
776
+ if v is None:
777
+ continue
778
+ log_data[f"loss/{k}"] = (v.item() if isinstance(v, torch.Tensor) else float(v))
779
+
780
+ log_data["loss/total"] = float(total_loss.item())
781
+ log_data["loss/lr"] = current_lr
782
+ for k, c in coeffs.items():
783
+ log_data[f"coeff/{k}"] = float(c)
784
+ if use_wandb and accelerator.sync_gradients:
785
+ wandb.log(log_data, step=global_step)
786
+
787
+ # Генерируем сэмплы с заданным интервалом
788
+ if global_step % sample_interval == 0:
789
+ generate_and_save_samples(fixed_samples,global_step)
790
+ last_n = sample_interval
791
+ avg_loss = float(np.mean(batch_losses[-last_n:])) if len(batch_losses) > 0 else 0.0
792
+ avg_tloss = float(np.mean(batch_tlosses[-last_n:])) if len(batch_tlosses) > 0 else 0.0
793
+ avg_grad = float(np.mean(batch_grads[-last_n:])) if len(batch_grads) > 0 else 0.0
794
+ print(f"Эпоха {epoch}, шаг {global_step}, средний лосс: {avg_loss:.6f}, grad: {avg_grad:.6f}")
795
+
796
+ if save_model:
797
+ print("saving:",avg_loss < min_loss*save_barrier)
798
+ if avg_loss < min_loss*save_barrier:
799
+ min_loss = avg_loss
800
+ save_checkpoint(unet)
801
+ if use_wandb:
802
+ avg_data = {}
803
+ avg_data["avg/loss"] = avg_loss
804
+ avg_data["avg/tloss"] = avg_tloss
805
+ avg_data["avg/grad"] = avg_grad
806
+ wandb.log(avg_data, step=global_step)
807
+
808
+ if accelerator.is_main_process:
809
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses)>0 else 0.0
810
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
811
+ if use_wandb:
812
+ wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch+1})
813
+
814
+ # Завершение обучения - сохраняем финальную модель
815
+ if loss_ratios.get("dispersive", 0) > 0:
816
+ dispersive_hook.remove_hooks()
817
+ if accelerator.is_main_process:
818
+ print("Обучение завершено! Сохраняем финальную моде��ь...")
819
+ if save_model:
820
+ save_checkpoint(unet,"fp16")
821
+ accelerator.free_memory()
822
+ if torch.distributed.is_initialized():
823
+ torch.distributed.destroy_process_group()
824
+
825
+ print("Готово!")